// based on the WebGL example in the `wasm-bindgen` guide
//
//   https://rustwasm.github.io/wasm-bindgen/examples/webgl.html
//
// and this StackOverflow answer by wangdq
//
//   https://stackoverflow.com/a/39684775
//

extern crate js_sys;
use sycamore::{prelude::*, rt::{JsCast, JsValue}};
use web_sys::{console, WebGl2RenderingContext, WebGlShader};

fn compile_shader(
    context: &WebGl2RenderingContext,
    shader_type: u32,
    source: &str,
) -> WebGlShader {
    let shader = context.create_shader(shader_type).unwrap();
    context.shader_source(&shader, source);
    context.compile_shader(&shader);
    shader
}

/// Uploads `data` into a fresh `ARRAY_BUFFER` and wires that buffer to the
/// vertex attribute at location `index`.
///
/// `size` is the number of `f32` components per vertex, so `data` is read as
/// tightly packed groups of `size` floats (zero stride, zero offset). The new
/// buffer is left bound to `ARRAY_BUFFER`.
///
/// # Panics
///
/// Panics if the context returns no buffer object.
fn bind_vertex_attrib(
    context: &WebGl2RenderingContext,
    index: u32,
    size: i32,
    data: &[f32],
) {
    // create a data buffer and bind it to ARRAY_BUFFER
    let buffer = context
        .create_buffer()
        .expect("failed to create buffer object");
    context.bind_buffer(WebGl2RenderingContext::ARRAY_BUFFER, Some(&buffer));

    // load the given data into the buffer
    //
    // SAFETY: `Float32Array::view` creates a raw view into our module's
    // `WebAssembly.Memory` buffer. allocating more memory can grow that memory,
    // which changes the buffer and invalidates the view. the view here is a
    // temporary that lives only for the duration of this one call, and nothing
    // between its creation and its drop allocates, so it stays valid while
    // `bufferData` reads (and copies) it
    unsafe {
        context.buffer_data_with_array_buffer_view(
            WebGl2RenderingContext::ARRAY_BUFFER,
            &js_sys::Float32Array::view(data),
            WebGl2RenderingContext::STATIC_DRAW,
        );
    }

    // allow the target attribute to be used
    context.enable_vertex_attrib_array(index);

    // take whatever's bound to ARRAY_BUFFER---here, the data buffer created
    // above---and bind it to the target attribute
    //
    //   https://developer.mozilla.org/en-US/docs/Web/API/WebGLRenderingContext/vertexAttribPointer
    //
    context.vertex_attrib_pointer_with_i32(
        index,
        size,
        WebGl2RenderingContext::FLOAT,
        false, // don't normalize
        0,     // zero stride: components are tightly packed
        0,     // zero offset: start at the beginning of the buffer
    );
}

/// Entry point: mounts a 600×600 WebGL2 canvas plus four range sliders, and
/// ray-casts two translucent spheres in a fragment shader.
///
/// The sliders drive shader uniforms:
/// - the first two shift each sphere's centre along the view axis (`ctrl`);
/// - the third sets the spheres' opacity;
/// - the fourth sets `layer_threshold`, the number of depth-sorted sphere
///   layers, counted from the front, that are skipped during compositing.
///
/// The scene is repainted by an effect whenever any of these signals change.
fn main() {
    // set up a config option that forwards panic messages to `console.error`
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
    
    sycamore::render(|| {
        let ctrl_x = create_signal(0.0);
        let ctrl_y = create_signal(0.0);
        let opacity = create_signal(0.5);
        let layer_threshold = create_signal(0.0);
        let display = create_node_ref();
        
        on_mount(move || {
            // get the display canvas
            let canvas = display
                .get::<DomNode>()
                .unchecked_into::<web_sys::HtmlCanvasElement>();
            let ctx = canvas
                .get_context("webgl2")
                .unwrap()
                .unwrap()
                .dyn_into::<WebGl2RenderingContext>()
                .unwrap();
            
            // compile and attach the vertex and fragment shaders
            let vertex_shader = compile_shader(
                &ctx,
                WebGl2RenderingContext::VERTEX_SHADER,
                r##"#version 300 es
                
                in vec4 position;
                
                void main() {
                    gl_Position = position;
                }
                "##,
            );
            let fragment_shader = compile_shader(
                &ctx,
                WebGl2RenderingContext::FRAGMENT_SHADER,
                r##"#version 300 es
            
                precision highp float;
                
                out vec4 outColor;
                
                // view
                uniform vec2 resolution;
                uniform float shortdim;
                
                // controls
                uniform vec2 ctrl;
                uniform float opacity;
                uniform int layer_threshold;
                
                // light and camera
                const float focal_slope = 0.3;
                const vec3 light_dir = normalize(vec3(2., 2., 1.));
                
                // --- sRGB ---
                
                // map colors from RGB space to sRGB space, as specified in the
                // sRGB standard (IEC 61966-2-1:1999)
                //
                //   https://www.color.org/sRGB.pdf
                //   https://www.color.org/chardata/rgb/srgb.xalter
                //
                // in RGB space, color value is proportional to light intensity,
                // so linear color-vector interpolation corresponds to physical
                // light mixing. in sRGB space, the color encoding used by many
                // monitors, we use more of the value interval to represent low
                // intensities, and less of the interval to represent high
                // intensities. this improves color quantization
                
                float sRGB(float t) {
                    if (t <= 0.0031308) {
                        return 12.92*t;
                    } else {
                        return 1.055*pow(t, 5./12.) - 0.055;
                    }
                }
                
                vec3 sRGB(vec3 color) {
                    return vec3(sRGB(color.r), sRGB(color.g), sRGB(color.b));
                }
                
                // --- inversive geometry ---
                
                struct vecInv {
                    vec3 sp;
                    vec2 lt;
                };
                
                vecInv sphere(vec3 center, float radius) {
                    return vecInv(
                        center / radius,
                        vec2(
                            0.5 / radius,
                            0.5 * (dot(center, center) / radius - radius)
                        )
                    );
                }
                
                // --- shading ---
                
                struct taggedFrag {
                    vec4 color;
                    float depth;
                };
                
                taggedFrag[2] sort(taggedFrag a, taggedFrag b) {
                    taggedFrag[2] result;
                    if (a.depth < b.depth) {
                        result[0] = a;
                        result[1] = b;
                    } else {
                        result[0] = b;
                        result[1] = a;
                    }
                    return result;
                }
                
                taggedFrag sphere_shading(vecInv v, vec3 pt, vec3 base_color) {
                    // the expression for normal needs to be checked. it's
                    // supposed to give the negative gradient of the lorentz
                    // product between the impact point vector and the sphere
                    // vector with respect to the coordinates of the impact
                    // point. i calculated it in my head and decided that
                    // the result looked good enough for now
                    vec3 normal = normalize(-v.sp + 2.*v.lt.s*pt);
                    
                    float incidence = dot(normal, light_dir);
                    float illum = mix(0.4, 1.0, max(incidence, 0.0));
                    return taggedFrag(vec4(illum * base_color, opacity), -pt.z);
                }
                
                // --- ray-casting ---
                
                vec2 sphere_cast(vecInv v, vec3 dir) {
                    float a = -v.lt.s * dot(dir, dir);
                    float b = dot(v.sp, dir);
                    float c = -v.lt.t;
                    
                    float scale = -b/(2.*a);
                    float adjust = 4.*a*c/(b*b);
                    
                    if (adjust < 1.) {
                        float offset = sqrt(1. - adjust);
                        return vec2(
                            scale * (1. - offset),
                            scale * (1. + offset)
                        );
                    } else {
                        // these parameters describe points behind the camera,
                        // so the corresponding fragments won't be drawn
                        return vec2(-1., -1.);
                    }
                }
                
                void main() {
                    vec2 scr = (2.*gl_FragCoord.xy - resolution) / shortdim;
                    vec3 dir = vec3(focal_slope * scr, -1.);
                    
                    // initialize two spheres
                    vecInv v0 = sphere(vec3(0.5, 0.5, -5. + ctrl.x), 1.);
                    vecInv v1 = sphere(vec3(-0.5, -0.5, -5. + ctrl.y), 1.);
                    vec3 color0 = vec3(1., 0.214, 0.);
                    vec3 color1 = vec3(0., 0.214, 1.);
                    
                    // cast rays through the spheres
                    vec2 u0 = sphere_cast(v0, dir);
                    vec2 u1 = sphere_cast(v1, dir);
                    
                    // shade and depth-sort the impact points
                    taggedFrag front_hits[2] = sort(
                        sphere_shading(v0, u0[0] * dir, color0),
                        sphere_shading(v1, u1[0] * dir, color1)
                    );
                    taggedFrag back_hits[2] = sort(
                        sphere_shading(v0, u0[1] * dir, color0),
                        sphere_shading(v1, u1[1] * dir, color1)
                    );
                    taggedFrag middle_frags[2] = sort(front_hits[1], back_hits[0]);
                    
                    // finish depth sorting
                    taggedFrag frags_by_depth[4];
                    frags_by_depth[0] = front_hits[0];
                    frags_by_depth[1] = middle_frags[0];
                    frags_by_depth[2] = middle_frags[1];
                    frags_by_depth[3] = back_hits[1];
                    
                    // composite the sphere fragments
                    vec3 color = vec3(0.);
                    for (int i = 3; i >= layer_threshold; --i) {
                        if (frags_by_depth[i].depth > 0.) {
                            vec4 frag_color = frags_by_depth[i].color;
                            color = mix(color, frag_color.rgb, frag_color.a);
                        }
                    }
                    outColor = vec4(sRGB(color), 1.);
                }
                "##,
            );
            let program = ctx.create_program().unwrap();
            ctx.attach_shader(&program, &vertex_shader);
            ctx.attach_shader(&program, &fragment_shader);
            ctx.link_program(&program);
            
            // `LINK_STATUS` comes back as a JS boolean; anything else (e.g.
            // `null`) is treated as a failure instead of panicking
            let link_status = ctx
                .get_program_parameter(&program, WebGl2RenderingContext::LINK_STATUS)
                .as_bool()
                .unwrap_or(false);
            if link_status {
                console::log_1(&JsValue::from("Linked successfully"));
            } else {
                // include the info log, which explains *why* linking failed
                let info_log = ctx.get_program_info_log(&program).unwrap_or_default();
                console::error_1(&JsValue::from(format!("Linking failed: {}", info_log)));
            }
            ctx.use_program(Some(&program));
            
            // find indices of vertex attributes and uniforms. `getAttribLocation`
            // returns -1 when the attribute isn't found (e.g. after a failed
            // link); casting that straight to `u32` would wrap to `u32::MAX`, so
            // bail out of setup instead
            let position_loc = ctx.get_attrib_location(&program, "position");
            if position_loc < 0 {
                console::error_1(&JsValue::from(
                    "Vertex attribute `position` not found; skipping scene setup",
                ));
                return;
            }
            let position_index = position_loc as u32;
            let resolution_loc = ctx.get_uniform_location(&program, "resolution");
            let shortdim_loc = ctx.get_uniform_location(&program, "shortdim");
            let ctrl_loc = ctx.get_uniform_location(&program, "ctrl");
            let opacity_loc = ctx.get_uniform_location(&program, "opacity");
            let layer_threshold_loc = ctx.get_uniform_location(&program, "layer_threshold");
            
            // create a vertex array and bind it to the graphics context
            let vertex_array = ctx.create_vertex_array().unwrap();
            ctx.bind_vertex_array(Some(&vertex_array));
            
            // set the vertex positions: two triangles covering the whole
            // clip-space square, so the fragment shader runs on every pixel
            const VERTEX_CNT: usize = 6;
            let positions: [f32; 3*VERTEX_CNT] = [
                // northwest triangle
                -1.0, -1.0, 0.0,
                -1.0,  1.0, 0.0,
                 1.0,  1.0, 0.0,
                // southeast triangle
                -1.0, -1.0, 0.0,
                 1.0,  1.0, 0.0,
                 1.0, -1.0, 0.0
            ];
            bind_vertex_attrib(&ctx, position_index, 3, &positions);
            
            // set up a repainting routine, re-run whenever a control signal
            // read below changes
            create_effect(move || {
                // set the resolution
                let width = canvas.width() as f32;
                let height = canvas.height() as f32;
                ctx.uniform2f(resolution_loc.as_ref(), width, height);
                ctx.uniform1f(shortdim_loc.as_ref(), width.min(height));
                
                // pass the control parameters
                ctx.uniform2f(ctrl_loc.as_ref(), ctrl_x.get() as f32, ctrl_y.get() as f32);
                ctx.uniform1f(opacity_loc.as_ref(), opacity.get() as f32);
                ctx.uniform1i(layer_threshold_loc.as_ref(), layer_threshold.get() as i32);
                
                // clear the screen and draw the scene
                ctx.clear_color(0.0, 0.0, 0.0, 1.0);
                ctx.clear(WebGl2RenderingContext::COLOR_BUFFER_BIT);
                ctx.draw_arrays(WebGl2RenderingContext::TRIANGLES, 0, VERTEX_CNT as i32);
            });
        });
        
        view! {
            div(id="app") {
                canvas(ref=display, width="600", height="600")
                input(
                    type="range",
                    min=-1.0,
                    max=1.0,
                    step=0.001,
                    bind:valueAsNumber=ctrl_x
                )
                input(
                    type="range",
                    min=-1.0,
                    max=1.0,
                    step=0.001,
                    bind:valueAsNumber=ctrl_y
                )
                input(
                    type="range",
                    max=1.0,
                    step=0.001,
                    bind:valueAsNumber=opacity
                )
                input(
                    type="range",
                    max=3.0,
                    step=1.0,
                    bind:valueAsNumber=layer_threshold
                )
            }
        }
    });
}