Record optimization history
This commit is contained in:

parent 9f8632efb3
commit ce33bbf418

1 changed file with 69 additions and 25 deletions
@@ -62,6 +62,30 @@ impl PartialMatrix {
     }
 }
 
+// --- descent history ---
+
+struct DescentHistory {
+    config: Vec<DMatrix<f64>>,
+    scaled_loss: Vec<f64>,
+    neg_grad: Vec<DMatrix<f64>>,
+    min_eigval: Vec<f64>,
+    base_step: Vec<DMatrix<f64>>,
+    backoff_steps: Vec<i32>
+}
+
+impl DescentHistory {
+    fn new() -> DescentHistory {
+        DescentHistory {
+            config: Vec::<DMatrix<f64>>::new(),
+            scaled_loss: Vec::<f64>::new(),
+            neg_grad: Vec::<DMatrix<f64>>::new(),
+            min_eigval: Vec::<f64>::new(),
+            base_step: Vec::<DMatrix<f64>>::new(),
+            backoff_steps: Vec::<i32>::new(),
+        }
+    }
+}
+
 // --- gram matrix realization ---
 
 // the Lorentz form
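A note on the layout (a reader's gloss, not text from the commit): the history is a set of parallel vectors with one entry per descent step, so a consumer walks the fields in lockstep. A minimal sketch of reading the record back, assuming the DescentHistory type above; summarize is a hypothetical helper, not part of the diff:

// hypothetical helper, not in the commit: print one line per recorded step
fn summarize(history: &DescentHistory) {
    for (step, loss) in history.scaled_loss.iter().enumerate() {
        // min_eigval and backoff_steps may hold one entry fewer than
        // scaled_loss, because the final pass records its loss and then
        // breaks out before the Hessian and line search run
        let eigval = history.min_eigval.get(step);
        let backoffs = history.backoff_steps.get(step);
        println!("step {step}: scaled loss {loss}, min eigenvalue {eigval:?}, backoff steps {backoffs:?}");
    }
}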
@@ -108,14 +132,14 @@ fn seek_better_config(
     min_efficiency: f64,
     backoff: f64,
     max_backoff_steps: i32
-) -> Option<SearchState> {
+) -> Option<(SearchState, i32)> {
     let mut rate = 1.0;
-    for _ in 0..max_backoff_steps {
+    for backoff_steps in 0..max_backoff_steps {
         let trial_config = &state.config + rate * base_step;
         let trial_state = SearchState::from_config(gram, trial_config);
         let improvement = state.loss - trial_state.loss;
         if improvement >= min_efficiency * rate * base_target_improvement {
-            return Some(trial_state);
+            return Some((trial_state, backoff_steps));
         }
         rate *= backoff;
     }
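For orientation: seek_better_config is the backtracking line search named in its caller's comment, with an Armijo-style sufficient-decrease test. On backoff step k = 0, 1, 2, ... the trial rate is r = backoff^k, and the trial configuration x + r Δx is accepted as soon as

$$ \operatorname{loss}(x) - \operatorname{loss}(x + r\,\Delta x) \;\ge\; \eta \, r \, d, $$

where η is min_efficiency and d is base_target_improvement, the gradient's predicted decrease neg_grad.dot(&base_step). The change in this hunk simply returns k alongside the accepted state so that realize_gram can record it in history.backoff_steps.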
@@ -134,7 +158,10 @@ fn realize_gram(
     reg_scale: f64,
     max_descent_steps: i32,
     max_backoff_steps: i32
-) -> (DMatrix<f64>, bool) {
+) -> (DMatrix<f64>, bool, DescentHistory) {
+    // start the descent history
+    let mut history = DescentHistory::new();
+    
     // find the dimension of the search space
     let element_dim = guess.nrows();
     let assembly_dim = guess.ncols();
@@ -153,13 +180,14 @@ fn realize_gram(
     let mut state = SearchState::from_config(gram, guess);
     for _ in 0..max_descent_steps {
         // stop if the loss is tolerably low
-        println!("scaled loss: {}", state.loss / scale_adjustment);
-        /* println!("projected error: {}", state.err_proj); */
+        history.config.push(state.config.clone());
+        history.scaled_loss.push(state.loss / scale_adjustment);
         if state.loss < tol { break; }
         
         // find the negative gradient of the loss function
         let neg_grad = 4.0 * &*Q * &state.config * &state.err_proj;
         let mut neg_grad_stacked = neg_grad.clone().reshape_generic(Dyn(total_dim), Const::<1>);
+        history.neg_grad.push(neg_grad.clone());
         
         // find the negative Hessian of the loss function
         let mut hess_cols = Vec::<DVector<f64>>::with_capacity(total_dim);
@@ -182,10 +210,10 @@ fn realize_gram(
         
         // regularize the Hessian
         let min_eigval = hess.symmetric_eigenvalues().min();
-        /* println!("lowest eigenvalue: {}", min_eigval); */
         if min_eigval <= 0.0 {
             hess -= reg_scale * min_eigval * DMatrix::identity(total_dim, total_dim);
         }
+        history.min_eigval.push(min_eigval);
         
         // project the negative gradient and negative Hessian onto the
         // orthogonal complement of the frozen subspace
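Why the shift works (again a gloss, not from the commit): subtracting reg_scale · λ_min · I raises every eigenvalue of the Hessian by −reg_scale · λ_min. With λ_min ≤ 0 and reg_scale > 1 (the tests pass 1.1), the smallest eigenvalue of the regularized Hessian is

$$ \lambda_{\min} - \mathtt{reg\_scale} \cdot \lambda_{\min} \;=\; (1 - \mathtt{reg\_scale})\,\lambda_{\min} \;\ge\; 0, $$

and strictly positive whenever λ_min < 0, which is what the Cholesky factorization in the next hunk relies on. The new history.min_eigval entry records the eigenvalue before the shift.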
@@ -207,17 +235,21 @@ fn realize_gram(
         */
         let base_step_stacked = hess.cholesky().unwrap().solve(&neg_grad_stacked);
         let base_step = base_step_stacked.reshape_generic(Dyn(element_dim), Dyn(assembly_dim));
+        history.base_step.push(base_step.clone());
         
         // use backtracking line search to find a better configuration
         match seek_better_config(
             gram, &state, &base_step, neg_grad.dot(&base_step),
             min_efficiency, backoff, max_backoff_steps
         ) {
-            Some(better_state) => state = better_state,
-            None => return (state.config, false)
+            Some((better_state, backoff_steps)) => {
+                state = better_state;
+                history.backoff_steps.push(backoff_steps);
+            },
+            None => return (state.config, false, history)
         };
     }
-    (state.config, state.loss < tol)
+    (state.config, state.loss < tol, history)
 }
 
 // --- tests ---
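The unchanged lines around this hunk solve the Newton system H s = −∇ loss by Cholesky factorization, then reshape the stacked solution back into matrix form. A self-contained sketch of that solve pattern in nalgebra, with a small made-up positive-definite system standing in for the project's Hessian and gradient:

use nalgebra::{DMatrix, DVector};

fn main() {
    // stand-ins for the regularized Hessian and the stacked negative gradient
    let hess = DMatrix::from_row_slice(2, 2, &[4.0, 1.0, 1.0, 3.0]);
    let neg_grad_stacked = DVector::from_vec(vec![1.0, 2.0]);

    // cholesky() returns None unless the matrix is positive definite,
    // which is what the eigenvalue shift above guarantees
    let base_step_stacked = hess.clone().cholesky()
        .expect("regularized Hessian should be positive definite")
        .solve(&neg_grad_stacked);

    // check the solve: hess * step should reproduce the right-hand side
    println!("step: {}", base_step_stacked);
    println!("residual: {}", (&hess * &base_step_stacked - &neg_grad_stacked).norm());
}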
@@ -329,29 +361,33 @@ mod tests {
         );
         let frozen: [(usize, usize); 4] = array::from_fn(|k| (3, k));
         const SCALED_TOL: f64 = 1.0e-12;
-        let (config, success) = realize_gram(
+        let (config, success, history) = realize_gram(
             &gram, guess, &frozen,
             SCALED_TOL, 0.5, 0.9, 1.1, 200, 110
         );
+        let entry_tol = SCALED_TOL.sqrt();
+        let solution_diams = [30.0, 10.0, 6.0, 5.0, 15.0, 10.0, 3.75, 2.5, 2.0 + 8.0/11.0];
+        for (k, diam) in solution_diams.into_iter().enumerate() {
+            assert!((config[(3, k)] - 1.0 / diam).abs() < entry_tol);
+        }
         print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
-        let final_state = SearchState::from_config(&gram, config);
         if success {
             println!("Target accuracy achieved!");
         } else {
             println!("Failed to reach target accuracy");
         }
-        println!("Loss: {}", final_state.loss);
+        println!("Steps: {}", history.scaled_loss.len() - 1);
+        println!("Loss: {}", history.scaled_loss.last().unwrap());
         if success {
             println!("\nChain diameters:");
-            println!("  {} sun (given)", 1.0 / final_state.config[(3, 3)]);
+            println!("  {} sun (given)", 1.0 / config[(3, 3)]);
             for k in 4..9 {
-                println!("  {} sun", 1.0 / final_state.config[(3, k)]);
+                println!("  {} sun", 1.0 / config[(3, k)]);
             }
         }
-        let entry_tol = SCALED_TOL.sqrt();
-        let solution_diams = [30.0, 10.0, 6.0, 5.0, 15.0, 10.0, 3.75, 2.5, 2.0 + 8.0/11.0];
-        for (k, diam) in solution_diams.into_iter().enumerate() {
-            assert!((final_state.config[(3, k)] - 1.0 / diam).abs() < entry_tol);
+        println!("\nStep │ Loss\n─────┼────────────────────────────────");
+        for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
+            println!("{:<4} │ {}", step, scaled_loss);
         }
     }
     
@@ -385,18 +421,22 @@ mod tests {
             ])
         };
         println!();
-        let (config, success) = realize_gram(
+        let (config, success, history) = realize_gram(
             &gram, guess, &[],
             1.0e-12, 0.5, 0.9, 1.1, 200, 110
         );
         print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
-        let final_state = SearchState::from_config(&gram, config);
         if success {
             println!("Target accuracy achieved!");
         } else {
             println!("Failed to reach target accuracy");
         }
-        println!("Loss: {}", final_state.loss);
+        println!("Steps: {}", history.scaled_loss.len() - 1);
+        println!("Loss: {}", history.scaled_loss.last().unwrap());
+        println!("\nStep │ Loss\n─────┼────────────────────────────────");
+        for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
+            println!("{:<4} │ {}", step, scaled_loss);
+        }
     }
     
     #[test]
@@ -419,18 +459,22 @@ mod tests {
         ]);
         let frozen = [(3, 0)];
         println!();
-        let (config, success) = realize_gram(
+        let (config, success, history) = realize_gram(
             &gram, guess, &frozen,
             1.0e-12, 0.5, 0.9, 1.1, 200, 110
         );
         print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
         print!("Configuration:{}", config);
-        let final_state = SearchState::from_config(&gram, config);
         if success {
             println!("Target accuracy achieved!");
         } else {
             println!("Failed to reach target accuracy");
         }
-        println!("Loss: {}", final_state.loss);
+        println!("Steps: {}", history.scaled_loss.len() - 1);
+        println!("Loss: {}", history.scaled_loss.last().unwrap());
+        println!("\nStep │ Loss\n─────┼────────────────────────────────");
+        for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
+            println!("{:<4} │ {}", step, scaled_loss);
+        }
     }
 }
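One consequence of the new signature, visible in all three tests above: every realize_gram call site now destructures a triple. A hypothetical caller with no use for the record can discard it:

// hypothetical call site, not from this diff: ignore the history
let (config, success, _) = realize_gram(
    &gram, guess, &frozen,
    1.0e-12, 0.5, 0.9, 1.1, 200, 110
);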