Track slope in gradient descent history
This commit is contained in:
parent
93dd05c317
commit
610fc451f0
@ -65,15 +65,17 @@ end
|
|||||||
# a type for keeping track of gradient descent history
|
# a type for keeping track of gradient descent history
|
||||||
struct DescentHistory{T}
|
struct DescentHistory{T}
|
||||||
scaled_loss::Array{T}
|
scaled_loss::Array{T}
|
||||||
|
slope::Array{T}
|
||||||
stepsize::Array{T}
|
stepsize::Array{T}
|
||||||
backoff_steps::Array{Int64}
|
backoff_steps::Array{Int64}
|
||||||
|
|
||||||
function DescentHistory{T}(
|
function DescentHistory{T}(
|
||||||
scaled_loss = Array{T}(undef, 0),
|
scaled_loss = Array{T}(undef, 0),
|
||||||
|
slope = Array{T}(undef, 0),
|
||||||
stepsize = Array{T}(undef, 0),
|
stepsize = Array{T}(undef, 0),
|
||||||
backoff_steps = Int64[]
|
backoff_steps = Int64[]
|
||||||
) where T
|
) where T
|
||||||
new(scaled_loss, stepsize, backoff_steps)
|
new(scaled_loss, slope, stepsize, backoff_steps)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -113,10 +115,11 @@ function realize_gram(
|
|||||||
neg_grad = 4*Q*L*Δ_proj
|
neg_grad = 4*Q*L*Δ_proj
|
||||||
slope = norm(neg_grad)
|
slope = norm(neg_grad)
|
||||||
|
|
||||||
# store current position and loss
|
# store current position, loss, and slope
|
||||||
L_last = L
|
L_last = L
|
||||||
loss_last = loss
|
loss_last = loss
|
||||||
push!(history.scaled_loss, loss / scale_adjustment)
|
push!(history.scaled_loss, loss / scale_adjustment)
|
||||||
|
push!(history.slope, slope)
|
||||||
|
|
||||||
# find a good step size using backtracking line search
|
# find a good step size using backtracking line search
|
||||||
push!(history.stepsize, 0)
|
push!(history.stepsize, 0)
|
||||||
|
Loading…
Reference in New Issue
Block a user