2024-07-02 21:57:57 +00:00
|
|
|
module Engine
|
|
|
|
|
|
|
|
using LinearAlgebra
|
2024-07-15 18:32:04 +00:00
|
|
|
using GenericLinearAlgebra
|
2024-07-02 21:57:57 +00:00
|
|
|
using SparseArrays
|
2024-07-03 00:16:31 +00:00
|
|
|
using Random
|
2024-07-15 18:32:04 +00:00
|
|
|
using Optim
|
2024-07-02 21:57:57 +00:00
|
|
|
|
2024-07-15 20:15:15 +00:00
|
|
|
export
|
|
|
|
rand_on_shell, Q, DescentHistory,
|
|
|
|
realize_gram_gradient, realize_gram_newton, realize_gram_optim, realize_gram
|
2024-07-03 00:16:31 +00:00
|
|
|
|
|
|
|
# === guessing ===
|
|
|
|
|
2024-07-07 04:32:43 +00:00
|
|
|
# the combination `0.5*(exp(t) + u*exp(-t))`. this equals `cosh(t)` when `u` is
# 1, `sinh(t)` when `u` is -1, and `exp(t)/2` when `u` is 0
function sconh(t, u)
    growing = exp(t)
    decaying = exp(-t)
    return 0.5 * (growing + u * decaying)
end
|
|
|
|
|
|
|
|
# draw a random unit vector of length `n` with entries of type `T`: a vector of
# standard normal samples, normalized. if a draw has squared length below 1e-6,
# it is redrawn, at most twice, before normalizing
function rand_on_sphere(rng::AbstractRNG, ::Type{T}, n) where T
    sample = randn(rng, T, n)
    for _ in 1:2
        # keep the current draw as soon as it is long enough to normalize
        dot(sample, sample) < 1e-6 || break
        sample = randn(rng, T, n)
    end
    return normalize(sample)
end
|
|
|
|
|
2024-07-03 00:16:31 +00:00
|
|
|
##[TO DO] write a test to confirm that the outputs are on the correct shells
|
2024-07-07 04:32:43 +00:00
|
|
|
# draw a random vector with five entries of type `T`. only the sign of `shell`
# is used: it selects which of the `sconh` combinations scales the four-entry
# unit vector and which one becomes the fifth entry. the result is expressed in
# the basis given by `nullmix`
function rand_on_shell(rng::AbstractRNG, shell::T) where T <: Number
    # the random draws happen in this order: direction first, then rapidity
    direction = rand_on_sphere(rng, T, 4)
    boost = randn(rng, T)
    shell_sign = sign(shell)
    space_coords = sconh(boost, shell_sign) * direction
    time_coord = sconh(boost, -shell_sign)
    return nullmix * vcat(space_coords, time_coord)
end
|
|
|
|
|
2024-07-07 04:32:43 +00:00
|
|
|
# draw one random vector for each entry of `shells`, in order, and place the
# vectors side by side as the columns of a matrix
function rand_on_shell(rng::AbstractRNG, shells::Array{T}) where T <: Number
    columns = map(sh -> rand_on_shell(rng, sh), shells)
    return hcat(columns...)
end
|
|
|
|
|
2024-07-03 00:16:31 +00:00
|
|
|
# convenience method: same as the method taking an explicit generator, using
# the default random number generator
function rand_on_shell(shells::Array{<:Number})
    rng = Random.default_rng()
    return rand_on_shell(rng, shells)
end
|
|
|
|
|
2024-07-08 19:56:14 +00:00
|
|
|
# === elements ===
|
|
|
|
|
2024-07-18 06:37:28 +00:00
|
|
|
# the vector representing a point: the entries of `pos`, followed by 0.5 and
# then half the squared length of `pos`
function point(pos)
    half_norm_sq = 0.5 * dot(pos, pos)
    return vcat(pos, 0.5, half_norm_sq)
end
|
2024-07-17 21:30:43 +00:00
|
|
|
|
2024-07-18 07:03:12 +00:00
|
|
|
# the vector representing a plane: the negated entries of `normal`, followed by
# 0 and then the negated `offset`
function plane(normal, offset)
    return vcat(-normal, 0, -offset)
end
|
2024-07-08 19:56:14 +00:00
|
|
|
|
|
|
|
# the vector representing a sphere: `center / radius`, followed by
# `0.5 / radius` and then `0.5 * (dot(center, center) / radius - radius)`
function sphere(center, radius)
    center_norm_sq = dot(center, center)
    scaled_center = center / radius
    inverse_radius_coord = 0.5 / radius
    last_coord = 0.5 * (center_norm_sq / radius - radius)
    return vcat(scaled_center, inverse_radius_coord, last_coord)
end
|
|
|
|
|
2024-07-03 00:16:31 +00:00
|
|
|
# === Gram matrix realization ===
|
2024-07-02 21:57:57 +00:00
|
|
|
|
2024-07-17 21:30:43 +00:00
|
|
|
# basis changes
|
2024-07-18 06:07:34 +00:00
|
|
|
# `nullmix` is the identity on the first three coordinates and acts on the last
# two by `[-1 1; 1 1]//2`. `unmix` has the same shape with `[-1 1; 1 1]` in the
# last block, which makes it the inverse of `nullmix`. both are declared `const`
# so that functions which read them are not slowed down by an untyped global
const nullmix = [
    Matrix{Int64}(I, 3, 3) zeros(Int64, 3, 2);
    zeros(Int64, 2, 3) [-1 1; 1 1]//2
]
const unmix = [
    Matrix{Int64}(I, 3, 3) zeros(Int64, 3, 2);
    zeros(Int64, 2, 3) [-1 1; 1 1]
]
|
2024-07-17 21:30:43 +00:00
|
|
|
|
2024-07-02 21:57:57 +00:00
|
|
|
# the Lorentz form
|
2024-07-18 06:07:34 +00:00
|
|
|
# the matrix of the form: the identity on the first three coordinates, and
# `[0 -2; -2 0]` on the last two. declared `const` because every solver reads
# it inside its inner loop, and an untyped global there defeats type inference
const Q = [
    Matrix{Int64}(I, 3, 3) zeros(Int64, 3, 2);
    zeros(Int64, 2, 3) [0 -2; -2 0]
]
|
2024-07-02 21:57:57 +00:00
|
|
|
|
2024-10-16 23:00:36 +00:00
|
|
|
# project a matrix onto the subspace of matrices whose entries vanish away from
|
|
|
|
# the given indices
|
2024-07-11 20:43:52 +00:00
|
|
|
# `mat` is any matrix; `indices` is an iterable of `(row, column)` pairs. the
# result has the same size and element type as `mat`, agrees with `mat` at the
# listed positions, and is zero everywhere else
function proj_to_entries(mat, indices)
    # allocate with the element type of `mat`, so that high-precision entries
    # (such as `BigFloat`) are not rounded to `Float64` when they are copied in
    result = zeros(eltype(mat), size(mat))
    for (j, k) in indices
        result[j, k] = mat[j, k]
    end
    return result
end
|
|
|
|
|
2024-07-02 21:57:57 +00:00
|
|
|
# the difference between the matrices `target` and `attempt`, projected onto the
|
|
|
|
# subspace of matrices whose entries vanish at each empty index of `target`
|
|
|
|
# returns a dense matrix of the same size and element type as `target`. at each
# stored entry of `target` it holds `target[j, k] - attempt[j, k]`; everywhere
# else it is zero
function proj_diff(target::SparseMatrixCSC{T, <:Any}, attempt::Matrix{T}) where T
    J, K, values = findnz(target)

    # allocate with element type `T`, so that the differences are not rounded
    # to `Float64` when `T` is a higher-precision type such as `BigFloat`
    result = zeros(T, size(target))
    for (j, k, val) in zip(J, K, values)
        result[j, k] = val - attempt[j, k]
    end
    return result
end
|
|
|
|
|
|
|
|
# a type for keeping track of gradient descent history
|
|
|
|
struct DescentHistory{T}
    # loss divided by the scale adjustment, one entry per recorded step
    scaled_loss::Array{T}
    # negative gradient of the loss at each step
    neg_grad::Array{Matrix{T}}
    # un-damped Newton step at each step
    base_step::Array{Matrix{T}}
    # Hessian of the loss at each step
    hess::Array{Hermitian{T, Matrix{T}}}
    # norm of the negative gradient at each step
    slope::Array{T}
    # step size (or rate) last tried by the line search at each step
    stepsize::Array{T}
    # whether the smallest eigenvalue of the Hessian was positive at each step
    positive::Array{Bool}
    # number of backoffs the line search used at each step
    backoff_steps::Array{Int64}
    # trial positions and scaled losses from the most recent line search
    last_line_L::Array{Matrix{T}}
    last_line_loss::Array{T}

    # every argument is optional and defaults to an empty array. note that the
    # positional order of the arguments puts `hess` before `base_step`, while
    # the fields are declared with `base_step` before `hess`
    function DescentHistory{T}(
        scaled_loss = Array{T}(undef, 0),
        neg_grad = Array{Matrix{T}}(undef, 0),
        hess = Array{Hermitian{T, Matrix{T}}}(undef, 0),
        base_step = Array{Matrix{T}}(undef, 0),
        slope = Array{T}(undef, 0),
        stepsize = Array{T}(undef, 0),
        positive = Bool[],
        backoff_steps = Int64[],
        last_line_L = Array{Matrix{T}}(undef, 0),
        last_line_loss = Array{T}(undef, 0)
    ) where T
        # `new` fills the fields in declaration order, so `base_step` has to be
        # passed before `hess` here for each array to land in its own field
        new(
            scaled_loss, neg_grad, base_step, hess, slope, stepsize, positive,
            backoff_steps, last_line_L, last_line_loss
        )
    end
end
|
|
|
|
|
|
|
|
# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
|
|
|
|
# explicit entry of `gram`. use gradient descent starting from `guess`
|
2024-07-11 20:43:52 +00:00
|
|
|
# arguments:
#   gram    sparse matrix whose stored entries are the values to match
#   guess   starting value for `L`; it is copied, not modified
# keywords:
#   scaled_tol         loss tolerance, before multiplication by the square
#                      root of the number of stored entries of `gram`
#   min_efficiency     fraction of the first-order predicted improvement that
#                      a trial step must achieve to be accepted
#   init_stepsize      first step size tried; the step size is carried over
#                      from one descent step to the next
#   backoff            factor applied to the step size after a rejected trial
#   max_descent_steps  maximum number of descent steps
#   max_backoff_steps  maximum number of backoffs within one line search
# returns `(L, history)`, where `history` is a `DescentHistory{T}`
function realize_gram_gradient(
    gram::SparseMatrixCSC{T, <:Any},
    guess::Matrix{T};
    scaled_tol = 1e-30,
    min_efficiency = 0.5,
    init_stepsize = 1.0,
    backoff = 0.9,
    max_descent_steps = 600,
    max_backoff_steps = 110
) where T <: Number
    # start history
    history = DescentHistory{T}()

    # scale tolerance
    scale_adjustment = sqrt(T(nnz(gram)))
    tol = scale_adjustment * scaled_tol

    # initialize variables
    stepsize = init_stepsize
    L = copy(guess)

    # do gradient descent
    Δ_proj = proj_diff(gram, L'*Q*L)
    loss = dot(Δ_proj, Δ_proj)
    for _ in 1:max_descent_steps
        # stop if the loss is tolerably low
        if loss < tol
            break
        end

        # find negative gradient of loss function
        neg_grad = 4*Q*L*Δ_proj
        slope = norm(neg_grad)

        # a vanishing gradient gives no descent direction, and dividing by the
        # zero slope would fill `L` with NaNs, so stop here instead
        if iszero(slope)
            break
        end
        dir = neg_grad / slope

        # store current position, loss, and slope
        L_last = L
        loss_last = loss
        push!(history.scaled_loss, loss / scale_adjustment)
        push!(history.neg_grad, neg_grad)
        push!(history.slope, slope)

        # find a good step size using backtracking line search
        push!(history.stepsize, 0)
        push!(history.backoff_steps, max_backoff_steps)
        empty!(history.last_line_L)
        empty!(history.last_line_loss)
        step_success = false
        for backoff_steps in 0:max_backoff_steps
            history.stepsize[end] = stepsize
            L = L_last + stepsize * dir
            Δ_proj = proj_diff(gram, L'*Q*L)
            loss = dot(Δ_proj, Δ_proj)
            improvement = loss_last - loss
            push!(history.last_line_L, L)
            push!(history.last_line_loss, loss / scale_adjustment)
            if improvement >= min_efficiency * stepsize * slope
                history.backoff_steps[end] = backoff_steps
                step_success = true
                break
            end
            stepsize *= backoff
        end

        # [DEBUG] if we've hit a wall, quit. an explicit flag is used because a
        # trial accepted on the very last permitted backoff also records
        # `max_backoff_steps`, and that case is a success, not a wall
        if !step_success
            break
        end
    end

    # return the factorization and its history
    push!(history.scaled_loss, loss / scale_adjustment)
    return L, history
end
|
|
|
|
|
2024-07-11 20:43:52 +00:00
|
|
|
# the matrix of size `dims` with element type `T` that has a one at position
# `(j, k)` and zeros everywhere else
function basis_matrix(::Type{T}, j, k, dims) where T
    unit = fill(zero(T), dims)
    unit[j, k] = one(T)
    return unit
end
|
|
|
|
|
|
|
|
# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
|
|
|
|
# explicit entry of `gram`. use Newton's method starting from `guess`
|
|
|
|
# arguments:
#   gram    sparse matrix whose stored entries are the values to match
#   guess   starting value for `L`; it is copied, not modified
# keywords:
#   scaled_tol  loss tolerance, before multiplication by the square root of
#               the number of stored entries of `gram`
#   rate        factor applied to each Newton step
#   max_steps   maximum number of Newton steps
# returns `(L, history)`. the last entry of `history.scaled_loss` is the scaled
# loss at the returned `L`
function realize_gram_newton(
    gram::SparseMatrixCSC{T, <:Any},
    guess::Matrix{T};
    scaled_tol = 1e-30,
    rate = 1,
    max_steps = 100
) where T <: Number
    # start history
    history = DescentHistory{T}()

    # find the dimension of the search space
    dims = size(guess)
    element_dim, construction_dim = dims
    total_dim = element_dim * construction_dim

    # list the constrained entries of the gram matrix
    J, K, _ = findnz(gram)
    constrained = zip(J, K)

    # scale the tolerance
    scale_adjustment = sqrt(T(length(constrained)))
    tol = scale_adjustment * scaled_tol

    # list the entries of `L` in column-major order, matching `reshape`
    indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]

    # use Newton's method
    L = copy(guess)
    for step in 0:max_steps
        # evaluate the loss function
        Δ_proj = proj_diff(gram, L'*Q*L)
        loss = dot(Δ_proj, Δ_proj)

        # store the current loss
        push!(history.scaled_loss, loss / scale_adjustment)

        # stop if the loss is tolerably low, or if the step budget is used up.
        # stopping here, after the loss has been stored, caps the number of
        # Newton steps at `max_steps` and keeps the stored losses in sync with
        # the returned `L`
        if loss < tol || step >= max_steps
            break
        end

        # find the negative gradient of loss function
        neg_grad = 4*Q*L*Δ_proj

        # find the Hessian of the loss function, one column per entry of `L`
        hess_dense = Matrix{T}(undef, total_dim, total_dim)
        for (j, k) in indices
            basis_mat = basis_matrix(T, j, k, dims)
            neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
            neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
            deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj)
            hess_dense[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
        end
        hess = Hermitian(hess_dense)
        push!(history.hess, hess)

        # compute the Newton step
        newton_step = hess \ reshape(neg_grad, total_dim)
        L += rate * reshape(newton_step, dims)
    end

    # return the factorization and its history
    return L, history
end
|
|
|
|
|
2024-07-15 18:32:04 +00:00
|
|
|
# method of `eigen!` for `Symmetric` matrices of `BigFloat`: re-wrap the matrix
# as `Hermitian` and call `eigen!` on that. the keyword `sortby` is required,
# accepts only `nothing`, and is not forwarded.
# NOTE(review): presumably this lets callers that pass `sortby = nothing` reach
# a `BigFloat`-capable `Hermitian` method from another package — confirm
function LinearAlgebra.eigen!(A::Symmetric{BigFloat, Matrix{BigFloat}}; sortby::Nothing)
    return eigen!(Hermitian(A))
end
|
|
|
|
|
2024-07-15 21:08:57 +00:00
|
|
|
# rebuild the sparse matrix `mat` with each stored value converted by `type`.
# the result has the same size and the same stored positions as `mat`
function convertnz(type, mat)
    J, K, values = findnz(mat)

    # pass the size explicitly: without it, `sparse` infers the size from the
    # largest stored index and drops any trailing empty rows or columns
    return sparse(J, K, type.(values), size(mat)...)
end
|
|
|
|
|
2024-07-15 18:32:04 +00:00
|
|
|
# seek a matrix `L` for which `L'QL` matches `gram` at every stored entry of
# `gram`, by handing the loss, its gradient, and its Hessian to `optimize` with
# the `Newton()` method. the search starts from `guess`, flattened to a vector.
# the loss is divided by the number of stored entries of `gram`. returns
# whatever `optimize` returns
function realize_gram_optim(
    gram::SparseMatrixCSC{T, <:Any},
    guess::Matrix{T}
) where T <: Number
    # dimensions of the search space
    dims = size(guess)
    n_rows, n_cols = dims
    n_vars = n_rows * n_cols

    # constrained entries of the gram matrix, and the loss scale they set
    rows, cols, _ = findnz(gram)
    constrained = zip(rows, cols)
    n_constraints = length(constrained)

    # entries of `L` in column-major order, matching `reshape`
    entries = [(j, k) for k in 1:n_cols for j in 1:n_rows]

    # projected difference between `gram` and the Gram matrix of `L`
    residual(L) = proj_diff(gram, L'*Q*L)

    function loss(L_vec)
        Δ_proj = residual(reshape(L_vec, dims))
        return dot(Δ_proj, Δ_proj) / n_constraints
    end

    function loss_grad!(storage, L_vec)
        L = reshape(L_vec, dims)
        Δ_proj = residual(L)
        storage .= reshape(-4*Q*L*Δ_proj, n_vars) / n_constraints
        return storage
    end

    function loss_hess!(storage, L_vec)
        L = reshape(L_vec, dims)
        Δ_proj = residual(L)

        # fill one column of the Hessian per entry of `L`
        for (j, k) in entries
            basis_mat = basis_matrix(T, j, k, dims)
            neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
            neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
            deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj) / n_constraints
            storage[:, (k-1)*n_rows + j] = reshape(deriv_grad, n_vars)
        end
        return nothing
    end

    start = reshape(guess, n_vars)
    return optimize(loss, loss_grad!, loss_hess!, start, Newton())
end
|
|
|
|
|
|
|
|
# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
|
|
|
|
# explicit entry of `gram`. use Newton's method with a regularized Hessian and
# a backtracking line search, starting from `guess`
|
|
|
|
# arguments:
#   gram    sparse matrix whose stored entries are the values to match
#   guess   starting value for `L`; it is copied, not modified
#   frozen  `nothing`, or indices into `guess` marking entries of `L` that
#           must keep their starting values
# keywords:
#   scaled_tol         loss tolerance, before multiplication by the square
#                      root of the number of stored entries of `gram`
#   min_efficiency     fraction of the first-order predicted improvement that
#                      a trial step must achieve to be accepted
#   init_rate          accepted for compatibility; not used by this function
#   backoff            factor applied to the rate after a rejected trial
#   reg_scale          multiple of the smallest Hessian eigenvalue that is
#                      subtracted from the diagonal when that eigenvalue is
#                      not positive
#   max_descent_steps  maximum number of Newton steps
#   max_backoff_steps  maximum number of backoffs within one line search
# returns `(L, success, history)`. `success` is `false` if a line search ran
# out of backoffs (in which case `L` is the last accepted position), and
# otherwise tells whether the final loss is below the tolerance
function realize_gram(
    gram::SparseMatrixCSC{T, <:Any},
    guess::Matrix{T},
    frozen = nothing;
    scaled_tol = 1e-30,
    min_efficiency = 0.5,
    init_rate = 1.0,
    backoff = 0.9,
    reg_scale = 1.1,
    max_descent_steps = 200,
    max_backoff_steps = 110
) where T <: Number
    # start history
    history = DescentHistory{T}()

    # find the dimension of the search space
    dims = size(guess)
    element_dim, construction_dim = dims
    total_dim = element_dim * construction_dim

    # list the constrained entries of the gram matrix
    J, K, _ = findnz(gram)
    constrained = zip(J, K)

    # scale the tolerance
    scale_adjustment = sqrt(T(length(constrained)))
    tol = scale_adjustment * scaled_tol

    # mark the un-frozen entries, in the same stacked order as `reshape`
    has_frozen = !isnothing(frozen)
    if has_frozen
        is_unfrozen = fill(true, size(guess))
        is_unfrozen[frozen] .= false
        unfrozen_stacked = reshape(is_unfrozen, total_dim)
    end

    # list the entries of `L` in column-major order, matching `reshape`
    indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]

    # initialize variables
    L = copy(guess)

    # use Newton's method with backtracking
    Δ_proj = proj_diff(gram, L'*Q*L)
    loss = dot(Δ_proj, Δ_proj)
    for _ in 1:max_descent_steps
        # stop if the loss is tolerably low
        if loss < tol
            break
        end

        # find the negative gradient of loss function
        neg_grad = 4*Q*L*Δ_proj

        # find the Hessian of the loss function, one column per entry of `L`
        hess_dense = Matrix{T}(undef, total_dim, total_dim)
        for (j, k) in indices
            basis_mat = basis_matrix(T, j, k, dims)
            neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
            neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
            deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj)
            hess_dense[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
        end
        hess = Hermitian(hess_dense)

        # the history records the Hessian before any regularization
        push!(history.hess, hess)

        # regularize the Hessian
        min_eigval = minimum(eigvals(hess))
        push!(history.positive, min_eigval > 0)
        if min_eigval <= 0
            hess -= reg_scale * min_eigval * I
        end

        # compute the Newton step, restricted to the un-frozen entries
        neg_grad_stacked = reshape(neg_grad, total_dim)
        if has_frozen
            hess = hess[unfrozen_stacked, unfrozen_stacked]
            neg_grad_compressed = neg_grad_stacked[unfrozen_stacked]
        else
            neg_grad_compressed = neg_grad_stacked
        end
        base_step_compressed = hess \ neg_grad_compressed
        if has_frozen
            # allocate with element type `T`, so that the step is not rounded
            # to `Float64` when `T` is a higher-precision type
            base_step_stacked = zeros(T, total_dim)
            base_step_stacked[unfrozen_stacked] .= base_step_compressed
        else
            base_step_stacked = base_step_compressed
        end
        base_step = reshape(base_step_stacked, dims)
        push!(history.base_step, base_step)

        # store the current position, loss, and slope
        L_last = L
        loss_last = loss
        push!(history.scaled_loss, loss / scale_adjustment)
        push!(history.neg_grad, neg_grad)
        push!(history.slope, norm(neg_grad))

        # find a good step size using backtracking line search
        push!(history.stepsize, 0)
        push!(history.backoff_steps, max_backoff_steps)
        empty!(history.last_line_L)
        empty!(history.last_line_loss)
        rate = one(T)
        step_success = false
        for backoff_steps in 0:max_backoff_steps
            history.stepsize[end] = rate
            L = L_last + rate * base_step
            Δ_proj = proj_diff(gram, L'*Q*L)
            loss = dot(Δ_proj, Δ_proj)
            improvement = loss_last - loss
            push!(history.last_line_L, L)
            push!(history.last_line_loss, loss / scale_adjustment)
            if improvement >= min_efficiency * rate * dot(neg_grad, base_step)
                history.backoff_steps[end] = backoff_steps
                step_success = true
                break
            end
            rate *= backoff
        end

        # if we've hit a wall, quit
        if !step_success
            return L_last, false, history
        end
    end

    # return the factorization and its history
    push!(history.scaled_loss, loss / scale_adjustment)
    return L, loss < tol, history
end
|
|
|
|
|
2024-07-02 21:57:57 +00:00
|
|
|
end
|