From 4d5ea062a3dc47ae92472dda23e7f8bdccc6c614 Mon Sep 17 00:00:00 2001 From: Aaron Fenyes Date: Tue, 9 Jul 2024 15:00:13 -0700 Subject: [PATCH] Record gradient and last line search in history --- engine-proto/gram-test/Engine.jl | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/engine-proto/gram-test/Engine.jl b/engine-proto/gram-test/Engine.jl index 2fb41e0..0539326 100644 --- a/engine-proto/gram-test/Engine.jl +++ b/engine-proto/gram-test/Engine.jl @@ -65,17 +65,23 @@ end # a type for keeping track of gradient descent history struct DescentHistory{T} scaled_loss::Array{T} + neg_grad::Array{Matrix{T}} slope::Array{T} stepsize::Array{T} backoff_steps::Array{Int64} + last_line_L::Array{Matrix{T}} + last_line_loss::Array{T} function DescentHistory{T}( scaled_loss = Array{T}(undef, 0), + neg_grad = Array{Matrix{T}}(undef, 0), slope = Array{T}(undef, 0), stepsize = Array{T}(undef, 0), - backoff_steps = Int64[] + backoff_steps = Int64[], + last_line_L = Array{Matrix{T}}(undef, 0), + last_line_loss = Array{T}(undef, 0) ) where T - new(scaled_loss, slope, stepsize, backoff_steps) + new(scaled_loss, neg_grad, slope, stepsize, backoff_steps, last_line_L, last_line_loss) end end @@ -119,23 +125,33 @@ function realize_gram( L_last = L loss_last = loss push!(history.scaled_loss, loss / scale_adjustment) + push!(history.neg_grad, neg_grad) push!(history.slope, slope) # find a good step size using backtracking line search push!(history.stepsize, 0) push!(history.backoff_steps, max_backoff_steps) + empty!(history.last_line_L) + empty!(history.last_line_loss) for backoff_steps in 0:max_backoff_steps history.stepsize[end] = stepsize L = L_last + stepsize * neg_grad Δ_proj = proj_diff(gram, L'*Q*L) loss = dot(Δ_proj, Δ_proj) improvement = loss_last - loss + push!(history.last_line_L, L) + push!(history.last_line_loss, loss / scale_adjustment) if improvement >= target_improvement * stepsize * slope history.backoff_steps[end] = backoff_steps break end stepsize *= backoff end + + # [DEBUG] if we've hit a wall, quit + if history.backoff_steps[end] == max_backoff_steps + break + end end # return the factorization and its history