Record gradient and last line search in history
This commit is contained in:
parent
5652719642
commit
4d5ea062a3
@ -65,17 +65,23 @@ end
|
||||
# a type for keeping track of gradient descent history
|
||||
struct DescentHistory{T}
|
||||
scaled_loss::Array{T}
|
||||
neg_grad::Array{Matrix{T}}
|
||||
slope::Array{T}
|
||||
stepsize::Array{T}
|
||||
backoff_steps::Array{Int64}
|
||||
last_line_L::Array{Matrix{T}}
|
||||
last_line_loss::Array{T}
|
||||
|
||||
function DescentHistory{T}(
|
||||
scaled_loss = Array{T}(undef, 0),
|
||||
neg_grad = Array{Matrix{T}}(undef, 0),
|
||||
slope = Array{T}(undef, 0),
|
||||
stepsize = Array{T}(undef, 0),
|
||||
backoff_steps = Int64[]
|
||||
backoff_steps = Int64[],
|
||||
last_line_L = Array{Matrix{T}}(undef, 0),
|
||||
last_line_loss = Array{T}(undef, 0)
|
||||
) where T
|
||||
new(scaled_loss, slope, stepsize, backoff_steps)
|
||||
new(scaled_loss, neg_grad, slope, stepsize, backoff_steps, last_line_L, last_line_loss)
|
||||
end
|
||||
end
|
||||
|
||||
@ -119,23 +125,33 @@ function realize_gram(
|
||||
L_last = L
|
||||
loss_last = loss
|
||||
push!(history.scaled_loss, loss / scale_adjustment)
|
||||
push!(history.neg_grad, neg_grad)
|
||||
push!(history.slope, slope)
|
||||
|
||||
# find a good step size using backtracking line search
|
||||
push!(history.stepsize, 0)
|
||||
push!(history.backoff_steps, max_backoff_steps)
|
||||
empty!(history.last_line_L)
|
||||
empty!(history.last_line_loss)
|
||||
for backoff_steps in 0:max_backoff_steps
|
||||
history.stepsize[end] = stepsize
|
||||
L = L_last + stepsize * neg_grad
|
||||
Δ_proj = proj_diff(gram, L'*Q*L)
|
||||
loss = dot(Δ_proj, Δ_proj)
|
||||
improvement = loss_last - loss
|
||||
push!(history.last_line_L, L)
|
||||
push!(history.last_line_loss, loss / scale_adjustment)
|
||||
if improvement >= target_improvement * stepsize * slope
|
||||
history.backoff_steps[end] = backoff_steps
|
||||
break
|
||||
end
|
||||
stepsize *= backoff
|
||||
end
|
||||
|
||||
# [DEBUG] if we've hit a wall, quit
|
||||
if history.backoff_steps[end] == max_backoff_steps
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
# return the factorization and its history
|
||||
|
Loading…
Reference in New Issue
Block a user