dyna3/engine-proto/gram-test/Engine.jl
Aaron Fenyes d538cbf716 Correct improvement threshold by using unit step
Our formula for the improvement threshold works when the step size is
an absolute distance. However, in commit `4d5ea06`, the step size was
measured relative to the current gradient instead. This commit scales
the base step to unit length, so now the step size really is an absolute
distance.
2024-07-10 23:31:44 -07:00

163 lines
4.4 KiB
Julia
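In `realize_gram` below, the base step `dir = neg_grad / slope` is scaled to unit length, so the backtracking line search accepts a step exactly when

    loss(L_last) - loss(L_last + stepsize*dir) >= target_improvement * stepsize * slope,

which is the standard sufficient-decrease (Armijo) condition with `stepsize` measured as an absolute distance along the descent direction.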

module Engine

using LinearAlgebra
using SparseArrays
using Random

export rand_on_shell, Q, DescentHistory, realize_gram

# === guessing ===
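# sconh(t, 1) = cosh(t) and sconh(t, -1) = sinh(t): a sign-switchable hyperbolic function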
sconh(t, u) = 0.5*(exp(t) + u*exp(-t))
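
# pick a uniformly random point on the unit sphere in n dimensions by normalizing
# a standard Gaussian sample, redrawing (at most twice) if the sample is too close
# to the origin to normalize safely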
function rand_on_sphere(rng::AbstractRNG, ::Type{T}, n) where T
    out = randn(rng, T, n)
    tries_left = 2
    while dot(out, out) < 1e-6 && tries_left > 0
        out = randn(rng, T, n)
        tries_left -= 1
    end
    normalize(out)
end

##[TO DO] write a test to confirm that the outputs are on the correct shells
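# pick a random point on the shell ⟨v, v⟩_Q = sign(shell), where Q is the Lorentz
# form defined below: with a unit space part, cosh² - sinh² = 1 gives the right sign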
function rand_on_shell(rng::AbstractRNG, shell::T) where T <: Number
    space_part = rand_on_sphere(rng, T, 4)
    rapidity = randn(rng, T)
    sig = sign(shell)
    [sconh(rapidity, sig)*space_part; sconh(rapidity, -sig)]
end

rand_on_shell(rng::AbstractRNG, shells::Array{T}) where T <: Number =
    hcat([rand_on_shell(rng, sh) for sh in shells]...)

rand_on_shell(shells::Array{<:Number}) = rand_on_shell(Random.default_rng(), shells)

# === elements ===
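# an element (plane or sphere) is encoded as a vector v with ⟨v, v⟩_Q = 1: this
# holds for `plane` whenever `normal` is a unit vector, and for `sphere` for any
# center and radius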
plane(normal, offset) = [normal; offset; offset]
function sphere(center, radius)
    dist_sq = dot(center, center)
    return [
        center / radius;
        0.5 * ((dist_sq - 1) / radius - radius);
        0.5 * ((dist_sq + 1) / radius - radius)
    ]
end

# === Gram matrix realization ===

# the Lorentz form
Q = diagm([1, 1, 1, 1, -1])

# the difference between the matrices `target` and `attempt`, projected onto the
# subspace of matrices whose entries vanish at each empty index of `target`
function proj_diff(target::SparseMatrixCSC{T, <:Any}, attempt::Matrix{T}) where T
    J, K, values = findnz(target)
    result = zeros(T, size(target)...)
    for (j, k, val) in zip(J, K, values)
        result[j, k] = val - attempt[j, k]
    end
    result
end

# a type for keeping track of gradient descent history
struct DescentHistory{T}
    scaled_loss::Array{T}          # loss at each descent step, divided by `scale_adjustment`
    neg_grad::Array{Matrix{T}}     # negative gradient at each descent step
    slope::Array{T}                # norm of the negative gradient at each descent step
    stepsize::Array{T}             # step size chosen by the line search at each descent step
    backoff_steps::Array{Int64}    # how many times the line search backed off at each step
    last_line_L::Array{Matrix{T}}  # trial points from the most recent line search
    last_line_loss::Array{T}       # scaled losses at those trial points

    function DescentHistory{T}(
        scaled_loss = Array{T}(undef, 0),
        neg_grad = Array{Matrix{T}}(undef, 0),
        slope = Array{T}(undef, 0),
        stepsize = Array{T}(undef, 0),
        backoff_steps = Int64[],
        last_line_L = Array{Matrix{T}}(undef, 0),
        last_line_loss = Array{T}(undef, 0)
    ) where T
        new(scaled_loss, neg_grad, slope, stepsize, backoff_steps, last_line_L, last_line_loss)
    end
end

# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use gradient descent starting from `guess`
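# in other words, minimize the loss ‖Π(gram - L'QL)‖², where Π zeroes every entry
# not explicitly present in `gram` (see `proj_diff` above)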
function realize_gram(
    gram::SparseMatrixCSC{T, <:Any},
    guess::Matrix{T};
    scaled_tol = 1e-30,
    target_improvement = 0.5,
    init_stepsize = 1.0,
    backoff = 0.9,
    max_descent_steps = 600,
    max_backoff_steps = 110
) where T <: Number
    # start history
    history = DescentHistory{T}()

    # scale the tolerance by the square root of the number of constrained entries
    scale_adjustment = sqrt(T(nnz(gram)))
    tol = scale_adjustment * scaled_tol

    # initialize variables
    stepsize = init_stepsize
    L = copy(guess)

    # do gradient descent
    Δ_proj = proj_diff(gram, L'*Q*L)
    loss = dot(Δ_proj, Δ_proj)
    for step in 1:max_descent_steps
        # stop if the loss is tolerably low
        if loss < tol
            break
        end

        # find the negative gradient of the loss function: with the loss
        # ⟨Δ_proj, Δ_proj⟩ and Δ_proj = Π(gram - L'QL), it works out to -4QLΔ_proj
        # as long as `gram` is symmetric
        neg_grad = 4*Q*L*Δ_proj
        slope = norm(neg_grad)
        dir = neg_grad / slope

        # store current position, loss, and slope
        L_last = L
        loss_last = loss
        push!(history.scaled_loss, loss / scale_adjustment)
        push!(history.neg_grad, neg_grad)
        push!(history.slope, slope)

        # find a good step size using backtracking line search
        push!(history.stepsize, 0)
        push!(history.backoff_steps, max_backoff_steps)
        empty!(history.last_line_L)
        empty!(history.last_line_loss)
        for backoff_steps in 0:max_backoff_steps
            history.stepsize[end] = stepsize
            L = L_last + stepsize * dir
            Δ_proj = proj_diff(gram, L'*Q*L)
            loss = dot(Δ_proj, Δ_proj)
            improvement = loss_last - loss
            push!(history.last_line_L, L)
            push!(history.last_line_loss, loss / scale_adjustment)
            if improvement >= target_improvement * stepsize * slope
                history.backoff_steps[end] = backoff_steps
                break
            end
            stepsize *= backoff
        end

        # [DEBUG] if we've hit a wall, quit
        if history.backoff_steps[end] == max_backoff_steps
            break
        end
    end

    # return the factorization and its history
    push!(history.scaled_loss, loss / scale_adjustment)
    L, history
end

end
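
A minimal usage sketch, not part of the file: it assumes `Engine.jl` is included from a script in the same directory, and the target Gram matrix below is illustrative. Two vectors of Q-norm 1 with Q-inner product -1 correspond, under the `sphere` parametrization above, to a pair of externally tangent spheres.

include("Engine.jl")
using SparseArrays

# target: two vectors of Q-norm 1 whose Q-inner product is -1 (illustrative values)
gram = sparse(
    [1, 2, 1, 2],
    [1, 1, 2, 2],
    [1.0, -1.0, -1.0, 1.0]
)

# random starting guess whose columns lie on the ⟨v, v⟩_Q = 1 shell
guess = Engine.rand_on_shell([1.0, 1.0])

L, history = Engine.realize_gram(gram, guess)
println("descent steps: ", length(history.scaled_loss) - 1)
println("final scaled loss: ", history.scaled_loss[end])

Whether the final scaled loss is close to zero tells you whether the descent actually realized the target Gram matrix within tolerance.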