Engine prototype #13

Merged
glen merged 133 commits from engine-proto into main 2024-10-21 03:18:48 +00:00
4 changed files with 41 additions and 82 deletions
Showing only changes of commit 7b3efbc385 - Show all commits

View File

@ -6,7 +6,9 @@ using SparseArrays
using Random
using Optim
export rand_on_shell, Q, DescentHistory, realize_gram
export
rand_on_shell, Q, DescentHistory,
realize_gram_gradient, realize_gram_newton, realize_gram_optim, realize_gram
# === guessing ===
@ -82,7 +84,7 @@ struct DescentHistory{T}
hess::Array{Hermitian{T, Matrix{T}}}
slope::Array{T}
stepsize::Array{T}
used_grad::Array{Bool}
positive::Array{Bool}
backoff_steps::Array{Int64}
last_line_L::Array{Matrix{T}}
last_line_loss::Array{T}
@ -94,12 +96,12 @@ struct DescentHistory{T}
base_step = Array{Matrix{T}}(undef, 0),
slope = Array{T}(undef, 0),
stepsize = Array{T}(undef, 0),
used_grad = Bool[],
positive = Bool[],
backoff_steps = Int64[],
last_line_L = Array{Matrix{T}}(undef, 0),
last_line_loss = Array{T}(undef, 0)
) where T
new(scaled_loss, neg_grad, hess, base_step, slope, stepsize, used_grad, backoff_steps, last_line_L, last_line_loss)
new(scaled_loss, neg_grad, hess, base_step, slope, stepsize, positive, backoff_steps, last_line_L, last_line_loss)
end
end
@ -305,7 +307,7 @@ end
function realize_gram(
gram::SparseMatrixCSC{T, <:Any},
guess::Matrix{T};
scaled_tol = 1e-30,
scaled_tol = 1e-16,
min_efficiency = 0.5,
init_rate = 1.0,
backoff = 0.9,
@ -358,54 +360,14 @@ function realize_gram(
hess = Hermitian(hess)
push!(history.hess, hess)
# choose a base step: the Newton step if the Hessian is non-singular, and
# the gradient descent direction otherwise
#=
sing = false
base_step = try
reshape(hess \ reshape(neg_grad, total_dim), dims)
catch ex
if isa(ex, SingularException)
sing = true
normalize(neg_grad)
else
throw(ex)
end
end
=#
#=
if !sing
rate = one(T)
end
=#
#=
if cond(Float64.(hess)) < 1e5
sing = false
base_step = reshape(hess \ reshape(neg_grad, total_dim), dims)
else
sing = true
base_step = normalize(neg_grad)
end
=#
#=
if cond(Float64.(hess)) > 1e3
sing = true
hess += big"1e-5"*I
else
sing = false
end
base_step = reshape(hess \ reshape(neg_grad, total_dim), dims)
=#
# regularize the Hessian
min_eigval = minimum(eigvals(hess))
if min_eigval < 0
push!(history.positive, min_eigval > 0)
if min_eigval <= 0
hess -= reg_scale * min_eigval * I
end
push!(history.used_grad, false)
base_step = reshape(hess \ reshape(neg_grad, total_dim), dims)
push!(history.base_step, base_step)
#=
push!(history.used_grad, sing)
=#
# store the current position, loss, and slope
L_last = L
@ -420,12 +382,9 @@ function realize_gram(
empty!(history.last_line_L)
empty!(history.last_line_loss)
rate = one(T)
step_success = false
for backoff_steps in 0:max_backoff_steps
history.stepsize[end] = rate
# try Newton step, but not on the first step. doing at least one step of
# gradient descent seems to help prevent getting stuck, for some reason?
if step > 0
L = L_last + rate * base_step
Δ_proj = proj_diff(gram, L'*Q*L)
loss = dot(Δ_proj, Δ_proj)
@ -434,36 +393,21 @@ function realize_gram(
push!(history.last_line_loss, loss / scale_adjustment)
if improvement >= min_efficiency * rate * dot(neg_grad, base_step)
history.backoff_steps[end] = backoff_steps
step_success = true
break
end
end
# try gradient descent step
slope = norm(neg_grad)
dir = neg_grad / slope
L = L_last + rate * grad_rate * dir
Δ_proj = proj_diff(gram, L'*Q*L)
loss = dot(Δ_proj, Δ_proj)
improvement = loss_last - loss
if improvement >= min_efficiency * rate * grad_rate * slope
grad_rate *= rate
history.used_grad[end] = true
history.backoff_steps[end] = backoff_steps
break
end
rate *= backoff
end
# [DEBUG] if we've hit a wall, quit
if history.backoff_steps[end] == max_backoff_steps
return L_last, history
# if we've hit a wall, quit
if !step_success
return L_last, false, history
end
end
# return the factorization and its history
push!(history.scaled_loss, loss / scale_adjustment)
L, history
L, true, history
end
end

View File

@ -86,7 +86,7 @@ L, history = Engine.realize_gram_gradient(gram, guess, scaled_tol = 0.01)
L_pol, history_pol = Engine.realize_gram_newton(gram, L, rate = 0.3, scaled_tol = 1e-9)
L_pol2, history_pol2 = Engine.realize_gram_newton(gram, L_pol)
=#
L, history = Engine.realize_gram(Float64.(gram), Float64.(guess))
L, success, history = Engine.realize_gram(Float64.(gram), Float64.(guess))
completed_gram = L'*Engine.Q*L
println("Completed Gram matrix:\n")
display(completed_gram)
@ -99,5 +99,10 @@ println(
)
println("Loss: ", history_pol2.scaled_loss[end], "\n")
=#
println("\nSteps: ", size(history.scaled_loss, 1))
if success
println("\nTarget accuracy achieved!")
else
println("\nFailed to reach target accuracy")
end
println("Steps: ", size(history.scaled_loss, 1))
println("Loss: ", history.scaled_loss[end], "\n")

View File

@ -52,7 +52,7 @@ guess = hcat(
L, history = Engine.realize_gram_gradient(gram, guess, scaled_tol = 0.01)
L_pol, history_pol = Engine.realize_gram_newton(gram, L)
=#
L, history = Engine.realize_gram(Float64.(gram), Float64.(guess))
L, success, history = Engine.realize_gram(Float64.(gram), Float64.(guess))
completed_gram = L'*Engine.Q*L
println("Completed Gram matrix:\n")
display(completed_gram)
@ -60,7 +60,12 @@ display(completed_gram)
println("\nSteps: ", size(history.scaled_loss, 1), " + ", size(history_pol.scaled_loss, 1))
println("Loss: ", history_pol.scaled_loss[end], "\n")
=#
println("\nSteps: ", size(history.scaled_loss, 1))
if success
println("\nTarget accuracy achieved!")
else
println("\nFailed to reach target accuracy")
end
println("Steps: ", size(history.scaled_loss, 1))
println("Loss: ", history.scaled_loss[end], "\n")
# === algebraic check ===

View File

@ -37,9 +37,14 @@ guess = sqrt(1/BigFloat(3)) * BigFloat[
#=
L, history = Engine.realize_gram_newton(gram, guess)
=#
L, history = Engine.realize_gram(gram, guess, max_descent_steps = 50)
L, success, history = Engine.realize_gram(gram, guess)
completed_gram = L'*Engine.Q*L
println("Completed Gram matrix:\n")
display(completed_gram)
println("\nSteps: ", size(history.scaled_loss, 1))
if success
println("\nTarget accuracy achieved!")
else
println("\nFailed to reach target accuracy")
end
println("Steps: ", size(history.scaled_loss, 1))
println("Loss: ", history.scaled_loss[end], "\n")