using LinearAlgebra
using AbstractAlgebra

"""
    printgood(msg)

Report success on standard output: a `✓` (styled green), a space, `msg`, and a
trailing newline.
"""
function printgood(msg)
  printstyled('✓'; color = :green)
  print(' ', msg, '\n')
end

"""
    printbad(msg)

Report failure on standard output: a `✗` (styled red), a space, `msg`, and a
trailing newline.
"""
function printbad(msg)
  printstyled('✗'; color = :red)
  print(' ', msg, '\n')
end

# rational function field in six variables over AbstractAlgebra's rationals;
# the generators are grouped into three pairs of unknowns
F, gens = rational_function_field(
  AbstractAlgebra.Rationals{BigInt}(),
  ["a₁", "a₂", "b₁", "b₂", "c₁", "c₂"]
)
a, b, c = gens[1:2], gens[3:4], gens[5:6]

# three mutually tangent spheres which are all perpendicular to the x, y plane:
# the Gram matrix has -1 in every diagonal entry and 1 in every off-diagonal entry
gram = [i == j ? -1 : 1 for i in 1:3, j in 1:3]

# classify the eigenvalues of `gram` by sign; the ±0.5 cutoffs presumably leave
# a margin around zero so floating-point noise is not counted as a sign
eig = eigen(gram)
n_pos = count(>(0.5), eig.values)
n_neg = count(<(-0.5), eig.values)
# every eigenvalue has to land clearly on one side of zero
if n_pos + n_neg != size(gram, 1)
  printbad("Degenerate subspace")
else
  printgood("Non-degenerate subspace")
end
# signs still needed to reach one positive and four negative directions once
# the eigenvalue signs of `gram` have been counted
sig_rem = Int64[fill(1, 1 - n_pos); fill(-1, 4 - n_neg)]
# the unknowns, one column per pair
unk = [a b c]
M = matrix_space(F, 5, 5)
# bordered matrix over F: remaining-signature block, unknowns, and known Gram block
big_gram = M(F.([
  diagm(sig_rem)  unk
  transpose(unk)  gram
]))

# LU decomposition over F; a non-identity row permutation is reported as failure
r, p, L, U = lu(big_gram)
if !isone(p)
  printbad("Didn't find a solution")
else
  printgood("Found a solution")
end
# with the identity permutation big_gram = L * U, hence
# transpose(solution) * mform * solution = L * U * inv(transpose(L)) * transpose(L)
# reproduces big_gram
solution = transpose(L)
mform = U * inv(solution)

# sample values substituted for the six unknowns; presumably matched by position
# to a₁, a₂, b₁, b₂, c₁, c₂ — confirm against `evaluate`
vals = [0, 0, 0, 1, 0, -3//4]
# evaluate every entry of the symbolic matrices at `vals`
solution_ex = [evaluate(entry, vals) for entry in solution]
mform_ex = [evaluate(entry, vals) for entry in mform]

# integer matrix assembled from blocks: rows 1-2 hold [1 1; 1 -1] in the last
# two columns, rows 3-5 hold the 3×3 identity in the first three columns
std_basis = [
  zeros(Int, 2, 3) [1 1; 1 -1];
  Matrix{Int}(I, 3, 3) zeros(Int, 3, 2)
]
# apply it to the symbolic solution (over F) and to the evaluated solution
std_solution = M(F.(std_basis)) * solution
std_solution_ex = std_basis * solution_ex

println("Minkowski form:")
display(mform_ex)

# rebuild the Gram matrix from the evaluated solution and form, then compare it
# entry by entry with the symbolic Gram matrix evaluated at the same values
big_gram_recovered = transpose(solution_ex) * mform_ex * solution_ex
valid = all(iszero, [evaluate(entry, vals) for entry in big_gram] - big_gram_recovered)
valid ? printgood("Recovered Gram matrix:") : printbad("Didn't recover Gram matrix. Instead, got:")
display(big_gram_recovered)

# hand-written candidate; this should be a solution
hand_solution = [
  0 0  1 0  0;
  0 0 -1 2  2;
  0 0  0 1 -1;
  1 0  0 0  0;
  0 1  0 0  0
]
# identity except for the top-left 2×2 block, which turns the first two rows of
# whatever it multiplies into half their sum and half their difference
unmix = Matrix{Rational{Int64}}(I, 5, 5)
unmix[1:2, 1:2] = [1//2 1//2; 1//2 -1//2]
hand_solution_diag = unmix * hand_solution
# Gram matrix with respect to the diagonal form with one +1 and four -1 entries
big_gram_hand_recovered =
  transpose(hand_solution_diag) * diagm([1, -1, -1, -1, -1]) * hand_solution_diag
println("Gram matrix from hand-written solution:")
display(big_gram_hand_recovered)