Record gradient and last line search in history

This commit is contained in:
Aaron Fenyes 2024-07-09 15:00:13 -07:00
parent 5652719642
commit 4d5ea062a3

View File

@@ -65,17 +65,23 @@ end
# A type for keeping track of gradient descent history.
#
# Each field is a per-iteration log appended to by the solver:
#   scaled_loss    — loss divided by the scale adjustment, one entry per iteration
#   neg_grad       — the negative gradient matrix recorded at each iteration
#   slope          — directional slope used by the backtracking line search
#   stepsize       — step size actually taken at each iteration
#   backoff_steps  — number of backtracking steps used (Int64 per iteration)
#   last_line_L    — trial L matrices from the most recent line search
#   last_line_loss — scaled losses of those trial points
#
# The inner constructor defaults every log to empty, so
# `DescentHistory{T}()` yields a fresh, empty history.
struct DescentHistory{T}
    scaled_loss::Array{T}
    neg_grad::Array{Matrix{T}}
    slope::Array{T}
    stepsize::Array{T}
    backoff_steps::Array{Int64}
    last_line_L::Array{Matrix{T}}
    last_line_loss::Array{T}

    function DescentHistory{T}(
        scaled_loss = Array{T}(undef, 0),
        neg_grad = Array{Matrix{T}}(undef, 0),
        slope = Array{T}(undef, 0),
        stepsize = Array{T}(undef, 0),
        backoff_steps = Int64[],
        last_line_L = Array{Matrix{T}}(undef, 0),
        last_line_loss = Array{T}(undef, 0)
    ) where T
        # Pass every field through in declaration order.
        new(scaled_loss, neg_grad, slope, stepsize, backoff_steps, last_line_L, last_line_loss)
    end
end
@@ -119,23 +125,33 @@ function realize_gram(
L_last = L L_last = L
loss_last = loss loss_last = loss
push!(history.scaled_loss, loss / scale_adjustment) push!(history.scaled_loss, loss / scale_adjustment)
push!(history.neg_grad, neg_grad)
push!(history.slope, slope) push!(history.slope, slope)
# find a good step size using backtracking line search # find a good step size using backtracking line search
push!(history.stepsize, 0) push!(history.stepsize, 0)
push!(history.backoff_steps, max_backoff_steps) push!(history.backoff_steps, max_backoff_steps)
empty!(history.last_line_L)
empty!(history.last_line_loss)
for backoff_steps in 0:max_backoff_steps for backoff_steps in 0:max_backoff_steps
history.stepsize[end] = stepsize history.stepsize[end] = stepsize
L = L_last + stepsize * neg_grad L = L_last + stepsize * neg_grad
Δ_proj = proj_diff(gram, L'*Q*L) Δ_proj = proj_diff(gram, L'*Q*L)
loss = dot(Δ_proj, Δ_proj) loss = dot(Δ_proj, Δ_proj)
improvement = loss_last - loss improvement = loss_last - loss
push!(history.last_line_L, L)
push!(history.last_line_loss, loss / scale_adjustment)
if improvement >= target_improvement * stepsize * slope if improvement >= target_improvement * stepsize * slope
history.backoff_steps[end] = backoff_steps history.backoff_steps[end] = backoff_steps
break break
end end
stepsize *= backoff stepsize *= backoff
end end
# [DEBUG] if we've hit a wall, quit
if history.backoff_steps[end] == max_backoff_steps
break
end
end end
# return the factorization and its history # return the factorization and its history