module Engine

using LinearAlgebra
using GenericLinearAlgebra
using SparseArrays
using Random
using Optim

export
  rand_on_shell, Q, DescentHistory,
  realize_gram_gradient, realize_gram_newton, realize_gram_optim,
  realize_gram_alt_proj, realize_gram

# === guessing ===

# a "signed cosh": sconh(t, 1) = cosh(t) and sconh(t, -1) = sinh(t)
sconh(t, u) = 0.5*(exp(t) + u*exp(-t))

function rand_on_sphere(rng::AbstractRNG, ::Type{T}, n) where T
  out = randn(rng, T, n)
  # resample if the draw is too close to zero to normalize stably
  tries_left = 2
  while dot(out, out) < 1e-6 && tries_left > 0
    out = randn(rng, T, n)
    tries_left -= 1
  end
  normalize(out)
end

## [TO DO] write a test to confirm that the outputs are on the correct shells
function rand_on_shell(rng::AbstractRNG, shell::T) where T <: Number
  space_part = rand_on_sphere(rng, T, 4)
  rapidity = randn(rng, T)
  sig = sign(shell)
  nullmix * [sconh(rapidity, sig)*space_part; sconh(rapidity, -sig)]
end

rand_on_shell(rng::AbstractRNG, shells::Array{T}) where T <: Number =
  hcat([rand_on_shell(rng, sh) for sh in shells]...)

rand_on_shell(shells::Array{<:Number}) = rand_on_shell(Random.default_rng(), shells)
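
# example (a sketch, not evaluated at load time). each output column v should
# land on the unit shell of the matching sign, in the sense that v'*Q*v equals
# sign(shell); the seed below is hypothetical:
#
#     rng = MersenneTwister(6071)
#     points = rand_on_shell(rng, BigFloat[-1, -1, 1])
#     [dot(points[:, n], Q*points[:, n]) for n in 1:3]  # ≈ [-1, -1, 1]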

# === elements ===

point(pos) = [pos; 0.5; 0.5 * dot(pos, pos)]

plane(normal, offset) = [-normal; 0; -offset]

function sphere(center, radius)
  dist_sq = dot(center, center)
  [
    center / radius;
    0.5 / radius;
    0.5 * (dist_sq / radius - radius)
  ]
end
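
# under the Lorentz form `Q` defined below, these coordinate vectors encode
# incidence and distance. a quick check (a sketch, not evaluated at load time):
#
#     p = [0.6, 0.8, 0.0]                      # a point on the unit sphere
#     dot(point(p), Q*sphere([0, 0, 0], 1.0))  # ≈ 0: the point lies on the sphere
#     dot(point(p), Q*point([0, 0, 0]))        # ≈ -0.5: in general, -|p - q|²/2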

# === Gram matrix realization ===

# basis changes. `nullmix` and `unmix` are mutually inverse: both act as the
# identity on the first three coordinates and mix the last two
nullmix = [Matrix{Int64}(I, 3, 3) zeros(Int64, 3, 2); zeros(Int64, 2, 3) [-1 1; 1 1]//2]
unmix = [Matrix{Int64}(I, 3, 3) zeros(Int64, 3, 2); zeros(Int64, 2, 3) [-1 1; 1 1]]

# the Lorentz form. in this basis, ⟨v, w⟩ = v₁w₁ + v₂w₂ + v₃w₃ - 2(v₄w₅ + v₅w₄)
Q = [Matrix{Int64}(I, 3, 3) zeros(Int64, 3, 2); zeros(Int64, 2, 3) [0 -2; -2 0]]

# project a matrix onto the subspace of matrices whose entries vanish away from
# the given indices
function proj_to_entries(mat, indices)
  result = zeros(eltype(mat), size(mat))
  for (j, k) in indices
    result[j, k] = mat[j, k]
  end
  result
end

# the difference between the matrices `target` and `attempt`, projected onto the
# subspace of matrices whose entries vanish at each empty index of `target`
function proj_diff(target::SparseMatrixCSC{T, <:Any}, attempt::Matrix{T}) where T
  J, K, values = findnz(target)
  result = zeros(T, size(target))
  for (j, k, val) in zip(J, K, values)
    result[j, k] = val - attempt[j, k]
  end
  result
end

# a type for keeping track of gradient descent history. `last_line_L` and
# `last_line_loss` record only the most recent line search; the other fields
# log one entry per descent step, when the solver in question records them
struct DescentHistory{T}
  scaled_loss::Array{T}
  neg_grad::Array{Matrix{T}}
  base_step::Array{Matrix{T}}
  hess::Array{Hermitian{T, Matrix{T}}}
  slope::Array{T}
  stepsize::Array{T}
  positive::Array{Bool}
  backoff_steps::Array{Int64}
  last_line_L::Array{Matrix{T}}
  last_line_loss::Array{T}
  
  function DescentHistory{T}(
    scaled_loss = Array{T}(undef, 0),
    neg_grad = Array{Matrix{T}}(undef, 0),
    base_step = Array{Matrix{T}}(undef, 0),
    hess = Array{Hermitian{T, Matrix{T}}}(undef, 0),
    slope = Array{T}(undef, 0),
    stepsize = Array{T}(undef, 0),
    positive = Bool[],
    backoff_steps = Int64[],
    last_line_L = Array{Matrix{T}}(undef, 0),
    last_line_loss = Array{T}(undef, 0)
  ) where T
    new(scaled_loss, neg_grad, base_step, hess, slope, stepsize, positive, backoff_steps, last_line_L, last_line_loss)
  end
end

# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use gradient descent starting from `guess`
function realize_gram_gradient(
  gram::SparseMatrixCSC{T, <:Any},
  guess::Matrix{T};
  scaled_tol = 1e-30,
  min_efficiency = 0.5,
  init_stepsize = 1.0,
  backoff = 0.9,
  max_descent_steps = 600,
  max_backoff_steps = 110
) where T <: Number
  # start history
  history = DescentHistory{T}()
  
  # scale tolerance
  scale_adjustment = sqrt(T(nnz(gram)))
  tol = scale_adjustment * scaled_tol
  
  # initialize variables
  stepsize = init_stepsize
  L = copy(guess)
  
  # do gradient descent
  Δ_proj = proj_diff(gram, L'*Q*L)
  loss = dot(Δ_proj, Δ_proj)
  for _ in 1:max_descent_steps
    # stop if the loss is tolerably low
    if loss < tol
      break
    end
    
    # find the negative gradient of the loss function
    neg_grad = 4*Q*L*Δ_proj
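    # (the loss is ⟨Δ, Δ⟩ with Δ = proj(gram - L'QL); differentiating in L and
    # using the symmetry of Δ and Q gives the gradient -4QLΔ)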
    slope = norm(neg_grad)
    dir = neg_grad / slope
    
    # store current position, loss, and slope
    L_last = L
    loss_last = loss
    push!(history.scaled_loss, loss / scale_adjustment)
    push!(history.neg_grad, neg_grad)
    push!(history.slope, slope)
    
    # find a good step size using backtracking line search
    push!(history.stepsize, 0)
    push!(history.backoff_steps, max_backoff_steps)
    empty!(history.last_line_L)
    empty!(history.last_line_loss)
    for backoff_steps in 0:max_backoff_steps
      history.stepsize[end] = stepsize
      L = L_last + stepsize * dir
      Δ_proj = proj_diff(gram, L'*Q*L)
      loss = dot(Δ_proj, Δ_proj)
      improvement = loss_last - loss
      push!(history.last_line_L, L)
      push!(history.last_line_loss, loss / scale_adjustment)
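      # accept the step if it meets the Armijo sufficient-decrease condition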
      if improvement >= min_efficiency * stepsize * slope
        history.backoff_steps[end] = backoff_steps
        break
      end
      stepsize *= backoff
    end
    
    # [DEBUG] if we've hit a wall, quit
    if history.backoff_steps[end] == max_backoff_steps
      break
    end
  end
  
  # return the factorization and its history
  push!(history.scaled_loss, loss / scale_adjustment)
  L, history
end

# the standard basis matrix with a one in entry (j, k) and zeros elsewhere
function basis_matrix(::Type{T}, j, k, dims) where T
  result = zeros(T, dims)
  result[j, k] = one(T)
  result
end

# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use Newton's method starting from `guess`
function realize_gram_newton(
  gram::SparseMatrixCSC{T, <:Any},
  guess::Matrix{T};
  scaled_tol = 1e-30,
  rate = 1,
  max_steps = 100
) where T <: Number
  # start history
  history = DescentHistory{T}()
  
  # find the dimension of the search space
  dims = size(guess)
  element_dim, construction_dim = dims
  total_dim = element_dim * construction_dim
  
  # list the constrained entries of the gram matrix
  J, K, _ = findnz(gram)
  constrained = zip(J, K)
  
  # scale the tolerance
  scale_adjustment = sqrt(T(length(constrained)))
  tol = scale_adjustment * scaled_tol
  
  # use Newton's method
  L = copy(guess)
  for step in 0:max_steps
    # evaluate the loss function
    Δ_proj = proj_diff(gram, L'*Q*L)
    loss = dot(Δ_proj, Δ_proj)
    
    # store the current loss
    push!(history.scaled_loss, loss / scale_adjustment)
    
    # stop if the loss is tolerably low, or if we've used all our steps
    if loss < tol || step == max_steps
      break
    end
    
    # find the negative gradient of the loss function
    neg_grad = 4*Q*L*Δ_proj
    
    # find the Hessian of the loss function, in stacked form: entry (j, k) of
    # an element_dim × construction_dim matrix sits at index
    # (k-1)*element_dim + j under Julia's column-major `reshape`
    hess = Matrix{T}(undef, total_dim, total_dim)
    indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]
    for (j, k) in indices
      basis_mat = basis_matrix(T, j, k, dims)
      neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
      neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
      deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj)
      hess[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
    end
    hess = Hermitian(hess)
    push!(history.hess, hess)
    
    # compute the Newton step
    base_step = hess \ reshape(neg_grad, total_dim)
    L += rate * reshape(base_step, dims)
  end
  
  # return the factorization and its history
  L, history
end
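
# `realize_gram_newton` converges quadratically, but only near a solution. one
# might use it to polish the output of `realize_gram_gradient` (a sketch, with
# `gram` and `guess` as in the other realize functions):
#
#     L₀, _ = realize_gram_gradient(gram, guess)
#     L, history = realize_gram_newton(gram, L₀)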

# work around the lack of a BigFloat eigensolver in LinearAlgebra by delegating
# `eigen!` on a BigFloat-valued `Symmetric` matrix to GenericLinearAlgebra's
# Hermitian method (the `sortby::Nothing` annotation matches calls that pass
# `sortby = nothing`)
LinearAlgebra.eigen!(A::Symmetric{BigFloat, Matrix{BigFloat}}; sortby::Nothing) =
  eigen!(Hermitian(A))

# convert the stored entries of a sparse matrix to the given type
function convertnz(type, mat)
  J, K, values = findnz(mat)
  sparse(J, K, type.(values))
end

# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use Optim.jl's `Newton()` solver, starting from
# `guess`
function realize_gram_optim(
  gram::SparseMatrixCSC{T, <:Any},
  guess::Matrix{T}
) where T <: Number
  # find the dimension of the search space
  dims = size(guess)
  element_dim, construction_dim = dims
  total_dim = element_dim * construction_dim
  
  # list the constrained entries of the gram matrix
  J, K, _ = findnz(gram)
  constrained = zip(J, K)
  
  # scale the loss function
  scale_adjustment = length(constrained)
  
  function loss(L_vec)
    L = reshape(L_vec, dims)
    Δ_proj = proj_diff(gram, L'*Q*L)
    dot(Δ_proj, Δ_proj) / scale_adjustment
  end
  
  function loss_grad!(storage, L_vec)
    L = reshape(L_vec, dims)
    Δ_proj = proj_diff(gram, L'*Q*L)
    storage .= reshape(-4*Q*L*Δ_proj, total_dim) / scale_adjustment
  end
  
  function loss_hess!(storage, L_vec)
    L = reshape(L_vec, dims)
    Δ_proj = proj_diff(gram, L'*Q*L)
    indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]
    for (j, k) in indices
      basis_mat = basis_matrix(T, j, k, dims)
      neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
      neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
      deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj) / scale_adjustment
      storage[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
    end
  end
  
  optimize(
    loss, loss_grad!, loss_hess!,
    reshape(guess, total_dim),
    Newton()
  )
end
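
# `realize_gram_optim` returns the raw Optim.jl result object. a sketch of how
# one might recover the realization from it, assuming convergence:
#
#     result = realize_gram_optim(gram, guess)
#     L = reshape(Optim.minimizer(result), size(guess))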

# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use a regularized Newton's method with backtracking
# line search, starting from `guess`, with an alternate technique for finding
# the projected base step from the unprojected Hessian
function realize_gram_alt_proj(
  gram::SparseMatrixCSC{T, <:Any},
  guess::Matrix{T},
  frozen = CartesianIndex[];
  scaled_tol = 1e-30,
  min_efficiency = 0.5,
  backoff = 0.9,
  reg_scale = 1.1,
  max_descent_steps = 200,
  max_backoff_steps = 110
) where T <: Number
  # start history
  history = DescentHistory{T}()
  
  # find the dimension of the search space
  dims = size(guess)
  element_dim, construction_dim = dims
  total_dim = element_dim * construction_dim
  
  # list the constrained entries of the gram matrix
  J, K, _ = findnz(gram)
  constrained = zip(J, K)
  
  # scale the tolerance
  scale_adjustment = sqrt(T(length(constrained)))
  tol = scale_adjustment * scaled_tol
  
  # convert the frozen indices to stacked format
  frozen_stacked = [(index[2]-1)*element_dim + index[1] for index in frozen]
  
  # initialize search state
  L = copy(guess)
  Δ_proj = proj_diff(gram, L'*Q*L)
  loss = dot(Δ_proj, Δ_proj)
  
  # use Newton's method with backtracking and gradient descent backup
  for step in 1:max_descent_steps
    # stop if the loss is tolerably low
    if loss < tol
      break
    end
    
    # find the negative gradient of the loss function
    neg_grad = 4*Q*L*Δ_proj
    
    # find the Hessian of the loss function
    hess = Matrix{T}(undef, total_dim, total_dim)
    indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]
    for (j, k) in indices
      basis_mat = basis_matrix(T, j, k, dims)
      neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
      neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
      deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj)
      hess[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
    end
    hess_sym = Hermitian(hess)
    push!(history.hess, hess_sym)
    
    # regularize the Hessian
    min_eigval = minimum(eigvals(hess_sym))
    push!(history.positive, min_eigval > 0)
    if min_eigval <= 0
      hess -= reg_scale * min_eigval * I
    end
    
    # compute the Newton step
    neg_grad_stacked = reshape(neg_grad, total_dim)
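    # freeze the chosen entries: zero their gradient components and replace
    # their Hessian rows and columns with the identity's, so the solve below
    # leaves them unchanged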
    for k in frozen_stacked
      neg_grad_stacked[k] = 0
      hess[k, :] .= 0
      hess[:, k] .= 0
      hess[k, k] = 1
    end
    base_step_stacked = Hermitian(hess) \ neg_grad_stacked
    base_step = reshape(base_step_stacked, dims)
    push!(history.base_step, base_step)
    
    # store the current position, loss, and slope
    L_last = L
    loss_last = loss
    push!(history.scaled_loss, loss / scale_adjustment)
    push!(history.neg_grad, neg_grad)
    push!(history.slope, norm(neg_grad))
    
    # find a good step size using backtracking line search
    push!(history.stepsize, 0)
    push!(history.backoff_steps, max_backoff_steps)
    empty!(history.last_line_L)
    empty!(history.last_line_loss)
    rate = one(T)
    step_success = false
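    # the first-order decrease predicted for a full step along `base_step`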
    base_target_improvement = dot(neg_grad, base_step)
    for backoff_steps in 0:max_backoff_steps
      history.stepsize[end] = rate
      L = L_last + rate * base_step
      Δ_proj = proj_diff(gram, L'*Q*L)
      loss = dot(Δ_proj, Δ_proj)
      improvement = loss_last - loss
      push!(history.last_line_L, L)
      push!(history.last_line_loss, loss / scale_adjustment)
      if improvement >= min_efficiency * rate * base_target_improvement
        history.backoff_steps[end] = backoff_steps
        step_success = true
        break
      end
      rate *= backoff
    end
    
    # if we've hit a wall, quit
    if !step_success
      return L_last, false, history
    end
  end
  
  # return the factorization and its history
  push!(history.scaled_loss, loss / scale_adjustment)
  L, loss < tol, history
end

# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use a regularized Newton's method with backtracking
# line search, starting from `guess`
function realize_gram(
  gram::SparseMatrixCSC{T, <:Any},
  guess::Matrix{T},
  frozen = nothing;
  scaled_tol = 1e-30,
  min_efficiency = 0.5,
  backoff = 0.9,
  reg_scale = 1.1,
  max_descent_steps = 200,
  max_backoff_steps = 110
) where T <: Number
  # start history
  history = DescentHistory{T}()
  
  # find the dimension of the search space
  dims = size(guess)
  element_dim, construction_dim = dims
  total_dim = element_dim * construction_dim
  
  # list the constrained entries of the gram matrix
  J, K, _ = findnz(gram)
  constrained = zip(J, K)
  
  # scale the tolerance
  scale_adjustment = sqrt(T(length(constrained)))
  tol = scale_adjustment * scaled_tol
  
  # mark the un-frozen entries
  has_frozen = !isnothing(frozen)
  if has_frozen
    is_unfrozen = fill(true, size(guess))
    is_unfrozen[frozen] .= false
    unfrozen_stacked = reshape(is_unfrozen, total_dim)
  end
  
  # initialize search state
  L = copy(guess)
  Δ_proj = proj_diff(gram, L'*Q*L)
  loss = dot(Δ_proj, Δ_proj)
  
  # use Newton's method with backtracking and gradient descent backup
  for step in 1:max_descent_steps
    # stop if the loss is tolerably low
    if loss < tol
      break
    end
    
    # find the negative gradient of the loss function
    neg_grad = 4*Q*L*Δ_proj
    
    # find the Hessian of the loss function
    hess = Matrix{T}(undef, total_dim, total_dim)
    indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]
    for (j, k) in indices
      basis_mat = basis_matrix(T, j, k, dims)
      neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
      neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
      deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj)
      hess[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
    end
    hess = Hermitian(hess)
    push!(history.hess, hess)
    
    # regularize the Hessian
    min_eigval = minimum(eigvals(hess))
    push!(history.positive, min_eigval > 0)
    if min_eigval <= 0
      hess -= reg_scale * min_eigval * I
    end
    
    # compute the Newton step
    neg_grad_stacked = reshape(neg_grad, total_dim)
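    # restrict the system to the un-frozen entries; the frozen ones get a zero step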
    if has_frozen
      hess = hess[unfrozen_stacked, unfrozen_stacked]
      neg_grad_compressed = neg_grad_stacked[unfrozen_stacked]
    else
      neg_grad_compressed = neg_grad_stacked
    end
    base_step_compressed = hess \ neg_grad_compressed
    if has_frozen
      base_step_stacked = zeros(T, total_dim)
      base_step_stacked[unfrozen_stacked] .= base_step_compressed
    else
      base_step_stacked = base_step_compressed
    end
    base_step = reshape(base_step_stacked, dims)
    push!(history.base_step, base_step)
    
    # store the current position, loss, and slope
    L_last = L
    loss_last = loss
    push!(history.scaled_loss, loss / scale_adjustment)
    push!(history.neg_grad, neg_grad)
    push!(history.slope, norm(neg_grad))
    
    # find a good step size using backtracking line search
    push!(history.stepsize, 0)
    push!(history.backoff_steps, max_backoff_steps)
    empty!(history.last_line_L)
    empty!(history.last_line_loss)
    rate = one(T)
    step_success = false
    base_target_improvement = dot(neg_grad, base_step)
    for backoff_steps in 0:max_backoff_steps
      history.stepsize[end] = rate
      L = L_last + rate * base_step
      Δ_proj = proj_diff(gram, L'*Q*L)
      loss = dot(Δ_proj, Δ_proj)
      improvement = loss_last - loss
      push!(history.last_line_L, L)
      push!(history.last_line_loss, loss / scale_adjustment)
      if improvement >= min_efficiency * rate * base_target_improvement
        history.backoff_steps[end] = backoff_steps
        step_success = true
        break
      end
      rate *= backoff
    end
    
    # if we've hit a wall, quit
    if !step_success
      return L_last, false, history
    end
  end
  
  # return the factorization and its history
  push!(history.scaled_loss, loss / scale_adjustment)
  L, loss < tol, history
end
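
# example usage (a sketch, not evaluated at load time). we seek two points
# whose Lorentz product is -1 (that is, whose Euclidean distance is √2),
# keeping each column on the point shell by pinning the diagonal to zero:
#
#     gram = sparse([1, 2, 1, 2], [1, 1, 2, 2], BigFloat[0, -1, -1, 0])
#     guess = hcat(point(BigFloat[0, 0, 1]), point(BigFloat[0, 0, -1]))
#     L, success, history = realize_gram(gram, guess)
#     (L'*Q*L)[1, 2]  # ≈ -1 when `success` is true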

end