#version 300 es

// use high precision for all float math in this shader
precision highp float;

// the fragment's output color. outside debug mode, `main` writes the
// composited color encoded with `sRGB`, with alpha 1
out vec4 outColor;

// --- inversive geometry ---

// a sphere in the inversive coordinates used by this shader. `sphere_cast`
// treats it as the zero set of
//
//   f(p) = -lt.s*|p|^2 + dot(sp, p) - lt.t
//
// so when `lt.s` is zero the zero set is a plane
struct vecInv {
    vec3 sp; // coefficient vector of the linear term of f
    vec2 lt; // lt.s: minus the |p|^2 coefficient; lt.t: minus the constant term
};

// --- uniforms ---

// assembly

// capacity of the per-sphere uniform arrays below
// NOTE(review): these three 200-entry arrays (one of a two-member struct)
// likely need several hundred uniform vectors, well above the GLSL ES 3.00
// minimum of 224 fragment uniform vectors, so this relies on the GPU exposing
// a larger limit — confirm on target hardware
const int SPHERE_MAX = 200;
// number of valid entries in the arrays below; `main` casts the view ray
// through the first `sphere_cnt` spheres, so the host should keep it at most
// SPHERE_MAX
uniform int sphere_cnt;
// the spheres to draw, in inversive coordinates
uniform vecInv sphere_list[SPHERE_MAX];
// base color of each sphere, in linear RGB (the composited result is encoded
// with `sRGB` at the end of `main`)
uniform vec3 color_list[SPHERE_MAX];
// per-sphere strength of the intersection and cusp highlights
uniform float highlight_list[SPHERE_MAX];

// view
// viewport size, in the same pixel units as gl_FragCoord
uniform vec2 resolution;
// length used to normalize screen coordinates; presumably the shorter of the
// two viewport dimensions — confirm against the host
uniform float shortdim;

// controls
// alpha given to every sphere fragment when compositing
uniform float opacity;
// number of nearest fragment layers left out of the composite; expected to be
// nonnegative — TODO confirm the host never sends a negative value
uniform int layer_threshold;
// when set, `main` shows the number of layers hit instead of the shaded image
uniform bool debug_mode;

// light and camera
// scales screen coordinates into view-ray slopes; larger values widen the
// field of view
const float focal_slope = 0.3;
// unit vector pointing toward the light, used for diffuse shading
const vec3 light_dir = normalize(vec3(2., 2., 1.));
// width scale of the intersection highlight band; also sets the width of the
// cusp highlight
const float ixn_threshold = 0.005;
// color factor applied to the second positive hit of each sphere along a ray
const float INTERIOR_DIMMING = 0.7;

// --- sRGB ---

// map colors from linear RGB space to sRGB space, as specified in the sRGB
// standard (IEC 61966-2-1:1999)
//
//   https://www.color.org/sRGB.pdf
//   https://www.color.org/chardata/rgb/srgb.xalter
//
// in linear RGB space, color value is proportional to light intensity, so
// linear color-vector interpolation corresponds to physical light mixing. in
// sRGB space, the color encoding used by many monitors, more of the value
// interval represents low intensities and less of it represents high
// intensities. this improves color quantization

float sRGB(float t) {
    // linear segment near black, 1/2.4 power curve above the cutoff
    return (t <= 0.0031308) ? 12.92*t : 1.055*pow(t, 5./12.) - 0.055;
}

vec3 sRGB(vec3 color) {
    // the transfer function acts on each channel independently
    float r = sRGB(color.r);
    float g = sRGB(color.g);
    float b = sRGB(color.b);
    return vec3(r, g, b);
}

// --- shading ---

// one ray-sphere hit, built by `sphere_shading` and kept depth-sorted in `main`
struct taggedFrag {
    int id;          // index of the hit sphere in `sphere_list`
    vec4 color;      // shaded color (rgb) and compositing opacity (a)
    float highlight; // highlight strength taken from `highlight_list[id]`
    vec3 pt;         // hit point; the camera sits at the origin looking down -z
    vec3 normal;     // unit surface normal at `pt`
};

taggedFrag sphere_shading(vecInv v, vec3 pt, vec3 base_color, float highlight, int id) {
    // `sphere_cast` finds hits as zeros of
    //
    //   f(p) = -v.lt.s*|p|^2 + dot(v.sp, p) - v.lt.t,
    //
    // whose gradient is v.sp - 2*v.lt.s*p. the normal is the normalized
    // negative gradient at the hit point, so its orientation follows the sign
    // convention baked into `v`
    vec3 grad_f = v.sp - 2.*v.lt.s*pt;
    vec3 normal = normalize(-grad_f);
    
    // diffuse lighting with an ambient floor of 0.4
    float lit = max(dot(normal, light_dir), 0.0);
    vec3 shaded = mix(0.4, 1.0, lit) * base_color;
    
    return taggedFrag(id, vec4(shaded, opacity), highlight, pt, normal);
}

// --- ray-casting ---

// if `|a/b|` is at most this threshold, `sphere_cast` treats `a*u^2 + b*u + c`
// as the linear function `b*u + c`: it keeps only the root adjacent to the
// linear function's zero and reports the other root as missing (-1)
const float DEG_THRESHOLD = 1e-9;

// the depths, represented as multiples of `dir`, where the line generated by
// `dir` hits the sphere represented by `v`. the sphere is the zero set of
//
//   f(p) = -v.lt.s*|p|^2 + dot(v.sp, p) - v.lt.t,
//
// so the depths are the roots `u` of `a*u^2 + b*u + c` with the coefficients
// computed below.
//
// returns:
// - (-1, -1) if the line misses the sphere or only touches it
// - if both depths are positive, the smaller one in the first component
// - if only one depth is positive, it in either component
// - when `|a/b| <= DEG_THRESHOLD`, the root adjacent to the zero of `b*u + c`
//   in the first component, and -1 in the second
vec2 sphere_cast(vecInv v, vec3 dir) {
    float a = -v.lt.s * dot(dir, dir);
    float b = dot(v.sp, dir);
    float c = -v.lt.t;
    
    float disc = b*b - 4.*a*c;
    if (disc > 0.) {
        // cancellation-free quadratic formula:
        //
        //   q = -(b + sgn(b)*sqrt(disc))/2,  roots c/q and q/a
        //
        // `b` and `sgn(b)*sqrt(disc)` share a sign, so the sum can't cancel and
        // |q| >= sqrt(disc)/2 > 0. the sign is chosen by hand because GLSL's
        // `sign(0.)` is 0, which would zero out `q` when `b = 0` (`v.sp` zero or
        // perpendicular to `dir`). dividing by `b` there used to produce NaN
        // depths, which made the sphere disappear
        float q = -0.5 * (b + (b < 0. ? -sqrt(disc) : sqrt(disc)));
        
        // `c/q` is the root adjacent to the zero of the linear approximation
        // `b*u + c`. if both roots have the same sign, it's the one closer to
        // `u = 0`, since |c/q| / |q/a| = a*c/q^2 <= 1 in that case
        float lin_root = c / q;
        if (abs(a) > DEG_THRESHOLD * abs(b)) {
            return vec2(lin_root, q / a);
        } else {
            // nearly linear: the other root runs off to an enormous depth
            // (or doesn't exist when a = 0), so report it as missing
            return vec2(lin_root, -1.);
        }
    } else {
        // the line through `dir` misses the sphere completely or only grazes it
        return vec2(-1., -1.);
    }
}

void main() {
    // screen position of this fragment, centered on the viewport and divided
    // by `shortdim` (presumably the shorter viewport dimension — confirm)
    vec2 scr = (2.*gl_FragCoord.xy - resolution) / shortdim;
    
    // direction of the view ray through this fragment. the camera sits at the
    // origin looking down the -z axis. `dir` is not normalized, so the depths
    // returned by `sphere_cast` are multiples of `dir`, not distances
    vec3 dir = vec3(focal_slope * scr, -1.);
    
    // cast rays through the spheres, collecting the hits in `frags`, sorted
    // from nearest to the screen (largest `pt.z`) to farthest. only the
    // LAYER_MAX nearest hits are kept; farther ones are dropped
    const int LAYER_MAX = 12;
    taggedFrag frags [LAYER_MAX];
    int layer_cnt = 0;
    
    // never index the uniform arrays past their declared length, even if the
    // host sends a larger count
    int sphere_total = min(sphere_cnt, SPHERE_MAX);
    for (int id = 0; id < sphere_total; ++id) {
        // find out where the ray hits the sphere
        vec2 hit_depths = sphere_cast(sphere_list[id], dir);
        
        // insertion-sort the fragments we hit into the fragment list. the
        // first positive hit gets the sphere's full color; a second positive
        // hit is dimmed by INTERIOR_DIMMING
        float dimming = 1.;
        for (int side = 0; side < 2; ++side) {
            float hit_z = -hit_depths[side];
            if (0. > hit_z) {
                for (int layer = layer_cnt; layer >= 0; --layer) {
                    if (layer < 1 || frags[layer-1].pt.z >= hit_z) {
                        // we're not as close to the screen as the fragment
                        // before the empty slot, so insert here. if the list
                        // is full and we're the farthest hit, we're dropped
                        if (layer < LAYER_MAX) {
                            frags[layer] = sphere_shading(
                                sphere_list[id],
                                hit_depths[side] * dir,
                                dimming * color_list[id],
                                highlight_list[id],
                                id
                            );
                        }
                        break;
                    } else if (layer < LAYER_MAX) {
                        // we're closer to the screen than the fragment before
                        // the empty slot, so move that fragment into the empty
                        // slot. when the list is full there is no slot past
                        // the end, so the farthest fragment is dropped instead
                        // of being written out of bounds
                        frags[layer] = frags[layer-1];
                    }
                }
                layer_cnt = min(layer_cnt + 1, LAYER_MAX);
                dimming = INTERIOR_DIMMING;
            }
        }
    }
    
    /* DEBUG */
    // in debug mode, show the layer count instead of the shaded image
    if (debug_mode) {
        // at the bottom of the screen, show the color scale instead of the
        // layer count
        if (gl_FragCoord.y < 10.) layer_cnt = int(16. * gl_FragCoord.x / resolution.x);
        
        // convert number to color: bits 0-2 pick the RGB corners, and bit 3
        // washes the color toward gray
        ivec3 bits = layer_cnt / ivec3(1, 2, 4);
        vec3 color = mod(vec3(bits), 2.);
        if (layer_cnt % 16 >= 8) {
            color = mix(color, vec3(0.5), 0.5);
        }
        outColor = vec4(color, 1.);
        return;
    }
    
    // highlight intersections and cusps
    for (int i = layer_cnt-1; i >= 1; --i) {
        // intersections: estimate how far this pair of depth-adjacent
        // fragments is from the line where their surfaces cross, and brighten
        // both when the estimate is within about `ixn_threshold`
        taggedFrag frag0 = frags[i];
        taggedFrag frag1 = frags[i-1];
        float ixn_sin = length(cross(frag0.normal, frag1.normal));
        vec3 disp = frag0.pt - frag1.pt;
        float ixn_dist = max(
            abs(dot(frag1.normal, disp)),
            abs(dot(frag0.normal, disp))
        ) / ixn_sin;
        float max_highlight = max(frags[i].highlight, frags[i-1].highlight);
        float ixn_highlight = 0.5 * max_highlight * (1. - smoothstep(2./3.*ixn_threshold, 1.5*ixn_threshold, ixn_dist));
        frags[i].color = mix(frags[i].color, vec4(1.), ixn_highlight);
        frags[i-1].color = mix(frags[i-1].color, vec4(1.), ixn_highlight);
        
        // cusps: brighten the fragment where the view ray is nearly tangent
        // to its surface. the test ignores the normal's orientation, so the
        // threshold does too; `abs` also keeps `sqrt` away from negative
        // arguments, where GLSL leaves the result undefined
        float cusp_cos = abs(dot(dir, frag0.normal));
        float cusp_threshold = 2.*sqrt(ixn_threshold * abs(sphere_list[frag0.id].lt.s));
        float highlight = frags[i].highlight;
        float cusp_highlight = highlight * (1. - smoothstep(2./3.*cusp_threshold, 1.5*cusp_threshold, cusp_cos));
        frags[i].color = mix(frags[i].color, vec4(1.), cusp_highlight);
    }
    
    // composite the sphere fragments from back to front, leaving out the
    // `layer_threshold` nearest layers. a negative threshold is treated as 0,
    // so the loop never reads before the start of `frags`
    int first_layer = max(layer_threshold, 0);
    vec3 color = vec3(0.);
    for (int i = layer_cnt-1; i >= first_layer; --i) {
        vec4 frag_color = frags[i].color;
        color = mix(color, frag_color.rgb, frag_color.a);
    }
    outColor = vec4(sRGB(color), 1.);
}