module Algebraic

export
  codimension, dimension,
  Construction, realize,
  Element, Point, Sphere,
  Relation, LiesOn, AlignsWithBy, mprod

import Subscripts
using LinearAlgebra
using AbstractAlgebra
using Groebner
using ...HittingSet

# --- commutative algebra ---

# as of version 0.36.6, AbstractAlgebra only supports ideals in multivariate
# polynomial rings when the coefficients are integers. we use Groebner to extend
# support to rationals and to finite fields of prime order
# generator reduction for ideals of multivariate polynomial rings whose
# coefficients are field elements: run `groebner` on the current generators and
# wrap the result in a new ideal over the same base ring
function Generic.reduce_gens(I::Generic.Ideal{U}) where {T <: FieldElement, U <: MPolyRingElem{T}}
  basis = groebner(gens(I))
  Generic.Ideal{U}(base_ring(I), basis)
end

# codimension of the ideal `I`, read off from the first (leading) terms of its
# generators
#
# for every nonzero generator we record which variables occur in its first
# term; the result is the size of the solution that `HittingSet.solve` returns
# for that family of variable sets. `maxdepth` is passed through to the solver
# unchanged
#
# NOTE(review): this presumably relies on `gens(I)` being a Groebner basis, as
# produced by `reduce_gens` — confirm for ideals built some other way
function codimension(I::Generic.Ideal{U}, maxdepth = Inf) where {T <: RingElement, U <: MPolyRingElem{T}}
  # a zero generator has no terms, so asking for the exponent vector of its
  # first term is invalid; it also puts no constraint on the hitting set
  nonzero = Iterators.filter(!iszero, gens(I))

  # the typed comprehension keeps the element type concrete even when every
  # generator was skipped
  targets = Set{Int}[Set(findall(!iszero, exponent_vector(f, 1))) for f in nonzero]
  length(HittingSet.solve(HittingSetProblem(targets), maxdepth))
end

# dimension of the ideal `I`: the number of generators (variables) of its base
# ring minus `codimension(I, maxdepth)`
function dimension(I::Generic.Ideal{U}, maxdepth = Inf) where {T <: RingElement, U <: MPolyRingElem{T}}
  nvariables = length(gens(base_ring(I)))
  return nvariables - codimension(I, maxdepth)
end

# --- primitive elements ---

# common supertype of the geometric elements (`Point`, `Sphere`) defined in this
# module; `T` is the coefficient type of their polynomial coordinates
abstract type Element{T} end

# a point element
#
#   coords  polynomial coordinates of the point
#   vec     coordinate vector derived from `coords`, or `nothing` until it has
#           been built
#   rel     always `nothing`: a point carries no extra relation of its own
mutable struct Point{T} <: Element{T}
  coords::Vector{MPolyRingElem{T}}
  vec::Union{Vector{MPolyRingElem{T}}, Nothing}
  rel::Nothing

  ## [to do] constructor argument never needed?
  function Point{T}(
    coords::Vector{MPolyRingElem{T}} = MPolyRingElem{T}[],
    vec::Union{Vector{MPolyRingElem{T}}, Nothing} = nothing
  ) where T
    new{T}(coords, vec, nothing)
  end
end

# fill in `pt.vec` from `pt.coords`: the vector starts with the unit of the
# coordinate ring, then the dot product of the coordinates with themselves,
# then the coordinates. requires `pt.coords` to be non-empty, since the ring is
# taken from its first entry
function buildvec!(pt::Point)
  xs = pt.coords
  unit = one(parent(xs[1]))
  pt.vec = [unit, dot(xs, xs), xs...]
end

# a sphere element
#
#   coords  polynomial coordinates of the sphere
#   vec     coordinate vector derived from `coords`, or `nothing` until it has
#           been built
#   rel     polynomial relation the coordinates must satisfy, or `nothing`
#           until it has been built
mutable struct Sphere{T} <: Element{T}
  coords::Vector{MPolyRingElem{T}}
  vec::Union{Vector{MPolyRingElem{T}}, Nothing}
  rel::Union{MPolyRingElem{T}, Nothing}

  ## [to do] constructor argument never needed?
  function Sphere{T}(
    coords::Vector{MPolyRingElem{T}} = MPolyRingElem{T}[],
    vec::Union{Vector{MPolyRingElem{T}}, Nothing} = nothing,
    rel::Union{MPolyRingElem{T}, Nothing} = nothing
  ) where T
    new{T}(coords, vec, rel)
  end
end

# fill in `sph.vec` and `sph.rel` from `sph.coords`. the vector is the
# coordinate list itself (the same array, not a copy), and the relation is
# `mprod(coords, coords) + 1` in the coordinate ring. requires `sph.coords` to
# be non-empty, since the ring is taken from its first entry
function buildvec!(sph::Sphere)
  xs = sph.coords
  unit = one(parent(xs[1]))
  sph.vec = xs
  sph.rel = mprod(xs, xs) + unit
end

# base names of the coordinate variables, keyed by element type name. entry `k`
# names the variable for slot `k` of an element's coordinate vector; `nothing`
# marks a slot that gets no variable of its own
const coordnames = IdDict{Symbol, Vector{Union{Symbol, Nothing}}}(
  nameof(Point) => Union{Symbol, Nothing}[nothing, nothing, :xₚ, :yₚ, :zₚ],
  nameof(Sphere) => Union{Symbol, Nothing}[:rₛ, :sₛ, :xₛ, :yₛ, :zₛ]
)

# base name of the variable for slot `index` of `elt`, looked up in
# `coordnames` by the name of the element's type; `nothing` if that slot has no
# variable
function coordname(elt::Element, index)
  names = coordnames[nameof(typeof(elt))]
  return names[index]
end

# append to `coordnamelist` the variable name for slot `coordindex` of an
# indexed element `(eltindex, elt)`: the slot's base name followed by the
# element index as a subscript. slots without a base name add nothing
function pushcoordname!(coordnamelist, indexed_elt::Tuple{Any, Element}, coordindex)
  (eltindex, elt) = indexed_elt
  stem = coordname(elt, coordindex)
  isnothing(stem) && return nothing
  push!(coordnamelist, Symbol(stem, Subscripts.sub(string(eltindex))))
end

# move the next entry of `coordlist` onto the end of the element's `coords`,
# provided slot `coordindex` of that element has a variable; otherwise leave
# both lists untouched
function takecoord!(coordlist, indexed_elt::Tuple{Any, Element}, coordindex)
  _, elt = indexed_elt
  isnothing(coordname(elt, coordindex)) && return nothing
  push!(elt.coords, popfirst!(coordlist))
end

# --- primitive relations ---

# common supertype of the relations between elements (`LiesOn`, `AlignsWithBy`)
# defined in this module; `T` matches the parameter of the related elements
abstract type Relation{T} end

# bilinear product of two coordinate vectors: half the symmetrized product of
# the first two entries, minus the dot product of the remaining entries
#
# NOTE(review): presumably the Minkowski-type product on the vectors built by
# `buildvec!` — confirm the intended sign convention
function mprod(v, w)
  mixed = v[1]*w[2] + w[1]*v[2]
  spatial = dot(v[3:end], w[3:end])
  return mixed / 2 - spatial
end

# elements: point, sphere
# relation between a point and a sphere; `elements` holds the point first and
# the sphere second
struct LiesOn{T} <: Relation{T}
  elements::Vector{Element{T}}

  function LiesOn{T}(pt::Point{T}, sph::Sphere{T}) where T
    new{T}(Element{T}[pt, sph])
  end
end

# polynomial expressing a `LiesOn` relation: the `mprod` of the coordinate
# vectors of its two elements. both vectors must already have been built
function equation(rel::LiesOn)
  first_elt, second_elt = rel.elements[1], rel.elements[2]
  return mprod(first_elt.vec, second_elt.vec)
end

# elements: sphere, sphere
# relation between two spheres, parameterized by `cos_angle`; `elements` holds
# the two spheres in the order given to the constructor
struct AlignsWithBy{T} <: Relation{T}
  elements::Vector{Element{T}}
  cos_angle::T

  function AlignsWithBy{T}(sph1::Sphere{T}, sph2::Sphere{T}, cos_angle::T) where T
    new{T}(Element{T}[sph1, sph2], cos_angle)
  end
end

# polynomial expressing an `AlignsWithBy` relation: the `mprod` of the
# coordinate vectors of its two elements, minus `cos_angle`. both vectors must
# already have been built
function equation(rel::AlignsWithBy)
  first_elt, second_elt = rel.elements[1], rel.elements[2]
  return mprod(first_elt.vec, second_elt.vec) - rel.cos_angle
end

# --- constructions ---

# a collection of points, spheres, and relations between them
#
#   points     the points of the construction
#   spheres    the spheres of the construction
#   relations  the relations of the construction
#
# the keyword constructor accepts any iterable of elements and any iterable of
# relations. every element mentioned by a relation is added to the
# construction, whether or not it was also listed in `elements`. elements that
# are neither a `Point` nor a `Sphere` are dropped
mutable struct Construction{T}
  points::Set{Point{T}}
  spheres::Set{Sphere{T}}
  relations::Set{Relation{T}}

  function Construction{T}(; elements = Set{Element{T}}(), relations = Set{Relation{T}}()) where T
    # build fresh, fully typed sets: this lets callers pass vectors or other
    # iterables, and keeps later `push!` calls on the construction from
    # mutating a set the caller still owns
    relset = Set{Relation{T}}(relations)
    allelements = Set{Element{T}}(elements)
    for rel in relset
      union!(allelements, rel.elements)
    end

    new{T}(
      Set{Point{T}}(Iterators.filter(elt -> elt isa Point, allelements)),
      Set{Sphere{T}}(Iterators.filter(elt -> elt isa Sphere, allelements)),
      relset
    )
  end
end

# add a point to the construction; returns the construction's set of points
Base.push!(ctx::Construction{T}, elt::Point{T}) where T = push!(ctx.points, elt)

# add a sphere to the construction; returns the construction's set of spheres
Base.push!(ctx::Construction{T}, elt::Sphere{T}) where T = push!(ctx.spheres, elt)

# add a relation to the construction, together with every element it refers
# to; returns `nothing`
function Base.push!(ctx::Construction{T}, rel::Relation{T}) where T
  push!(ctx.relations, rel)
  foreach(elt -> push!(ctx, elt), rel.elements)
end

# turn a construction into a system of polynomial equations
#
# a polynomial ring is created with one variable for every named coordinate
# slot of every element (spheres are numbered first, then points). each
# element's `coords`, `vec`, and `rel` fields are overwritten to live in that
# ring, so calling this mutates the elements of `ctx`
#
# returns a tuple `(ideal, eqns)`, where `eqns` is the list of equations
# (relation equations, per-element relations, then the normalization
# equations) and `ideal` is the ideal they generate in the coordinate ring
function realize(ctx::Construction{T}) where T
  # collect coordinate names
  coordnamelist = Symbol[]
  eltenum = enumerate(Iterators.flatten((ctx.spheres, ctx.points)))
  for coordindex in 1:5
    for indexed_elt in eltenum
      pushcoordname!(coordnamelist, indexed_elt, coordindex)
    end
  end

  # construct coordinate ring
  coordring, coordqueue = polynomial_ring(parent_type(T)(), coordnamelist, ordering = :degrevlex)

  # retrieve coordinates, in the same order the names were collected
  for (_, elt) in eltenum
    empty!(elt.coords)
  end
  for coordindex in 1:5
    for indexed_elt in eltenum
      takecoord!(coordqueue, indexed_elt, coordindex)
    end
  end

  # construct coordinate vectors
  for (_, elt) in eltenum
    buildvec!(elt)
  end

  # turn relations into equations. the list is typed by the ring's element
  # type up front, so it stays a vector of polynomials even when the
  # construction has no relations (broadcasting over an empty relation set
  # would otherwise yield an abstractly typed vector)
  eqns = typeof(one(coordring))[]
  for rel in ctx.relations
    push!(eqns, equation(rel))
  end
  for (_, elt) in eltenum
    isnothing(elt.rel) || push!(eqns, elt.rel)
  end

  # add relations to center, orient, and scale the construction
  if !isempty(ctx.points)
    append!(eqns, [sum(pt.coords[k] for pt in ctx.points) for k in 1:3])
  end
  if !isempty(ctx.spheres)
    append!(eqns, [sum(sph.coords[k] for sph in ctx.spheres) for k in 3:4])
  end
  n_elts = length(ctx.points) + length(ctx.spheres)
  if n_elts > 0
    push!(eqns, sum(elt.vec[2] for elt in Iterators.flatten((ctx.points, ctx.spheres))) - n_elts)
  end

  (Generic.Ideal(coordring, eqns), eqns)
end

end