Record optimization history
parent 9f8632efb3
commit ce33bbf418
@@ -62,6 +62,30 @@ impl PartialMatrix {
     }
 }
 
+// --- descent history ---
+
+struct DescentHistory {
+    config: Vec<DMatrix<f64>>,
+    scaled_loss: Vec<f64>,
+    neg_grad: Vec<DMatrix<f64>>,
+    min_eigval: Vec<f64>,
+    base_step: Vec<DMatrix<f64>>,
+    backoff_steps: Vec<i32>
+}
+
+impl DescentHistory {
+    fn new() -> DescentHistory {
+        DescentHistory {
+            config: Vec::<DMatrix<f64>>::new(),
+            scaled_loss: Vec::<f64>::new(),
+            neg_grad: Vec::<DMatrix<f64>>::new(),
+            min_eigval: Vec::<f64>::new(),
+            base_step: Vec::<DMatrix<f64>>::new(),
+            backoff_steps: Vec::<i32>::new(),
+        }
+    }
+}
+
 // --- gram matrix realization ---
 
 // the Lorentz form
@@ -108,14 +132,14 @@ fn seek_better_config(
     min_efficiency: f64,
     backoff: f64,
     max_backoff_steps: i32
-) -> Option<SearchState> {
+) -> Option<(SearchState, i32)> {
     let mut rate = 1.0;
-    for _ in 0..max_backoff_steps {
+    for backoff_steps in 0..max_backoff_steps {
         let trial_config = &state.config + rate * base_step;
         let trial_state = SearchState::from_config(gram, trial_config);
         let improvement = state.loss - trial_state.loss;
         if improvement >= min_efficiency * rate * base_target_improvement {
-            return Some(trial_state);
+            return Some((trial_state, backoff_steps));
         }
         rate *= backoff;
     }
@@ -134,7 +158,10 @@ fn realize_gram(
     reg_scale: f64,
     max_descent_steps: i32,
     max_backoff_steps: i32
-) -> (DMatrix<f64>, bool) {
+) -> (DMatrix<f64>, bool, DescentHistory) {
+    // start the descent history
+    let mut history = DescentHistory::new();
+
     // find the dimension of the search space
     let element_dim = guess.nrows();
     let assembly_dim = guess.ncols();
@@ -153,13 +180,14 @@ fn realize_gram(
     let mut state = SearchState::from_config(gram, guess);
     for _ in 0..max_descent_steps {
         // stop if the loss is tolerably low
-        println!("scaled loss: {}", state.loss / scale_adjustment);
-        /* println!("projected error: {}", state.err_proj); */
+        history.config.push(state.config.clone());
+        history.scaled_loss.push(state.loss / scale_adjustment);
         if state.loss < tol { break; }
 
         // find the negative gradient of the loss function
         let neg_grad = 4.0 * &*Q * &state.config * &state.err_proj;
         let mut neg_grad_stacked = neg_grad.clone().reshape_generic(Dyn(total_dim), Const::<1>);
+        history.neg_grad.push(neg_grad.clone());
 
         // find the negative Hessian of the loss function
         let mut hess_cols = Vec::<DVector<f64>>::with_capacity(total_dim);
@@ -182,10 +210,10 @@ fn realize_gram(
 
         // regularize the Hessian
         let min_eigval = hess.symmetric_eigenvalues().min();
-        /* println!("lowest eigenvalue: {}", min_eigval); */
         if min_eigval <= 0.0 {
             hess -= reg_scale * min_eigval * DMatrix::identity(total_dim, total_dim);
         }
+        history.min_eigval.push(min_eigval);
 
         // project the negative gradient and negative Hessian onto the
         // orthogonal complement of the frozen subspace
@@ -207,17 +235,21 @@ fn realize_gram(
         */
         let base_step_stacked = hess.cholesky().unwrap().solve(&neg_grad_stacked);
         let base_step = base_step_stacked.reshape_generic(Dyn(element_dim), Dyn(assembly_dim));
+        history.base_step.push(base_step.clone());
 
         // use backtracking line search to find a better configuration
         match seek_better_config(
             gram, &state, &base_step, neg_grad.dot(&base_step),
             min_efficiency, backoff, max_backoff_steps
         ) {
-            Some(better_state) => state = better_state,
-            None => return (state.config, false)
+            Some((better_state, backoff_steps)) => {
+                state = better_state;
+                history.backoff_steps.push(backoff_steps);
+            },
+            None => return (state.config, false, history)
         };
     }
-    (state.config, state.loss < tol)
+    (state.config, state.loss < tol, history)
 }
 
 // --- tests ---
@@ -329,29 +361,33 @@ mod tests {
         );
         let frozen: [(usize, usize); 4] = array::from_fn(|k| (3, k));
         const SCALED_TOL: f64 = 1.0e-12;
-        let (config, success) = realize_gram(
+        let (config, success, history) = realize_gram(
            &gram, guess, &frozen,
            SCALED_TOL, 0.5, 0.9, 1.1, 200, 110
         );
+        let entry_tol = SCALED_TOL.sqrt();
+        let solution_diams = [30.0, 10.0, 6.0, 5.0, 15.0, 10.0, 3.75, 2.5, 2.0 + 8.0/11.0];
+        for (k, diam) in solution_diams.into_iter().enumerate() {
+            assert!((config[(3, k)] - 1.0 / diam).abs() < entry_tol);
+        }
         print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
-        let final_state = SearchState::from_config(&gram, config);
         if success {
             println!("Target accuracy achieved!");
         } else {
             println!("Failed to reach target accuracy");
         }
-        println!("Loss: {}", final_state.loss);
+        println!("Steps: {}", history.scaled_loss.len() - 1);
+        println!("Loss: {}", history.scaled_loss.last().unwrap());
         if success {
             println!("\nChain diameters:");
-            println!(" {} sun (given)", 1.0 / final_state.config[(3, 3)]);
+            println!(" {} sun (given)", 1.0 / config[(3, 3)]);
             for k in 4..9 {
-                println!(" {} sun", 1.0 / final_state.config[(3, k)]);
+                println!(" {} sun", 1.0 / config[(3, k)]);
             }
         }
-        let entry_tol = SCALED_TOL.sqrt();
-        let solution_diams = [30.0, 10.0, 6.0, 5.0, 15.0, 10.0, 3.75, 2.5, 2.0 + 8.0/11.0];
-        for (k, diam) in solution_diams.into_iter().enumerate() {
-            assert!((final_state.config[(3, k)] - 1.0 / diam).abs() < entry_tol);
+        println!("\nStep │ Loss\n─────┼────────────────────────────────");
+        for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
+            println!("{:<4} │ {}", step, scaled_loss);
         }
     }
 
@@ -385,18 +421,22 @@ mod tests {
            ])
        };
        println!();
-        let (config, success) = realize_gram(
+        let (config, success, history) = realize_gram(
            &gram, guess, &[],
            1.0e-12, 0.5, 0.9, 1.1, 200, 110
        );
        print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
-        let final_state = SearchState::from_config(&gram, config);
        if success {
            println!("Target accuracy achieved!");
        } else {
            println!("Failed to reach target accuracy");
        }
-        println!("Loss: {}", final_state.loss);
+        println!("Steps: {}", history.scaled_loss.len() - 1);
+        println!("Loss: {}", history.scaled_loss.last().unwrap());
+        println!("\nStep │ Loss\n─────┼────────────────────────────────");
+        for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
+            println!("{:<4} │ {}", step, scaled_loss);
+        }
    }
 
    #[test]
@@ -419,18 +459,22 @@ mod tests {
        ]);
        let frozen = [(3, 0)];
        println!();
-        let (config, success) = realize_gram(
+        let (config, success, history) = realize_gram(
            &gram, guess, &frozen,
            1.0e-12, 0.5, 0.9, 1.1, 200, 110
        );
        print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
        print!("Configuration:{}", config);
-        let final_state = SearchState::from_config(&gram, config);
        if success {
            println!("Target accuracy achieved!");
        } else {
            println!("Failed to reach target accuracy");
        }
-        println!("Loss: {}", final_state.loss);
+        println!("Steps: {}", history.scaled_loss.len() - 1);
+        println!("Loss: {}", history.scaled_loss.last().unwrap());
+        println!("\nStep │ Loss\n─────┼────────────────────────────────");
+        for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
+            println!("{:<4} │ {}", step, scaled_loss);
+        }
    }
}
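
Note on consuming the new return value: `realize_gram` now hands the recorded `DescentHistory` back alongside the configuration and success flag, so a caller can inspect the whole descent after a run. The sketch below is a minimal, hypothetical consumer that is not part of this commit; it assumes the crate's existing nalgebra dependency and the struct exactly as added above, and the `summarize` function and its output format are illustrative only.

// Hypothetical consumer of the recorded history -- a sketch, not code from this commit.
use nalgebra::DMatrix;

#[allow(dead_code)]
struct DescentHistory {
    config: Vec<DMatrix<f64>>,
    scaled_loss: Vec<f64>,
    neg_grad: Vec<DMatrix<f64>>,
    min_eigval: Vec<f64>,
    base_step: Vec<DMatrix<f64>>,
    backoff_steps: Vec<i32>
}

// Print one line per recorded step. `neg_grad`, `min_eigval`, and
// `backoff_steps` can be one entry shorter than `scaled_loss`, because the
// descent loop records the configuration and scaled loss first and breaks
// once the loss is below tolerance, so missing entries are shown as `None`.
fn summarize(history: &DescentHistory) {
    for (step, loss) in history.scaled_loss.iter().enumerate() {
        let grad_norm = history.neg_grad.get(step).map(|g| g.norm());
        let min_eig = history.min_eigval.get(step);
        let backoffs = history.backoff_steps.get(step);
        println!(
            "step {:>3}: scaled loss {:e}, |grad| {:?}, min eigval {:?}, backoff steps {:?}",
            step, loss, grad_norm, min_eig, backoffs
        );
    }
}

fn main() {
    // Tiny hand-built history, just to make the sketch runnable on its own.
    let history = DescentHistory {
        config: vec![DMatrix::zeros(2, 2); 2],
        scaled_loss: vec![1.0e-3, 1.0e-14],
        neg_grad: vec![DMatrix::zeros(2, 2)],
        min_eigval: vec![0.5],
        base_step: vec![DMatrix::zeros(2, 2)],
        backoff_steps: vec![0]
    };
    summarize(&history);
}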