Port the Gram matrix realization routine to Rust
Validate with the process inspection example tests, which print out their results and optimization histories when run one at a time in `--nocapture` mode.
This commit is contained in:
		
							parent
							
								
									e59d60bf77
								
							
						
					
					
						commit
						9fe03264ab
					
				
					 3 changed files with 328 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -1,4 +1,11 @@
 | 
			
		|||
use nalgebra::DVector;
 | 
			
		||||
use lazy_static::lazy_static;
 | 
			
		||||
use nalgebra::{Const, DMatrix, DVector, Dyn};
 | 
			
		||||
 | 
			
		||||
// --- elements ---
 | 
			
		||||
 | 
			
		||||
pub fn point(x: f64, y: f64, z: f64) -> DVector<f64> {
 | 
			
		||||
    DVector::from_column_slice(&[x, y, z, 0.5, 0.5*(x*x + y*y + z*z)])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// the sphere with the given center and radius, with inward-pointing normals
 | 
			
		||||
pub fn sphere(center_x: f64, center_y: f64, center_z: f64, radius: f64) -> DVector<f64> {
 | 
			
		||||
| 
						 | 
				
			
			@ -24,4 +31,316 @@ pub fn sphere_with_offset(dir_x: f64, dir_y: f64, dir_z: f64, off: f64, curv: f6
 | 
			
		|||
        0.5 * curv,
 | 
			
		||||
        off * (1.0 + 0.5 * off * curv)
 | 
			
		||||
    ])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// --- partial matrices ---
 | 
			
		||||
 | 
			
		||||
struct MatrixEntry {
 | 
			
		||||
    index: (usize, usize),
 | 
			
		||||
    val: f64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct PartialMatrix(Vec<MatrixEntry>);
 | 
			
		||||
 | 
			
		||||
impl PartialMatrix {
 | 
			
		||||
    fn proj(&self, a: &DMatrix<f64>) -> DMatrix<f64> {
 | 
			
		||||
        let mut result = DMatrix::<f64>::zeros(a.nrows(), a.ncols());
 | 
			
		||||
        let PartialMatrix(entries) = self;
 | 
			
		||||
        for ent in entries {
 | 
			
		||||
            result[ent.index] = a[ent.index];
 | 
			
		||||
        }
 | 
			
		||||
        result
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    fn sub_proj(&self, rhs: &DMatrix<f64>) -> DMatrix<f64> {
 | 
			
		||||
        let mut result = DMatrix::<f64>::zeros(rhs.nrows(), rhs.ncols());
 | 
			
		||||
        let PartialMatrix(entries) = self;
 | 
			
		||||
        for ent in entries {
 | 
			
		||||
            result[ent.index] = ent.val - rhs[ent.index];
 | 
			
		||||
        }
 | 
			
		||||
        result
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// --- gram matrix realization ---
 | 
			
		||||
 | 
			
		||||
// the Lorentz form
 | 
			
		||||
lazy_static! {
 | 
			
		||||
    static ref Q: DMatrix<f64> = DMatrix::from_row_slice(5, 5, &[
 | 
			
		||||
        1.0, 0.0, 0.0,  0.0,  0.0,
 | 
			
		||||
        0.0, 1.0, 0.0,  0.0,  0.0,
 | 
			
		||||
        0.0, 0.0, 1.0,  0.0,  0.0,
 | 
			
		||||
        0.0, 0.0, 0.0,  0.0, -2.0,
 | 
			
		||||
        0.0, 0.0, 0.0, -2.0,  0.0
 | 
			
		||||
    ]);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct SearchState {
 | 
			
		||||
    config: DMatrix<f64>,
 | 
			
		||||
    err_proj: DMatrix<f64>,
 | 
			
		||||
    loss: f64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl SearchState {
 | 
			
		||||
    fn from_config(gram: &PartialMatrix, config: DMatrix<f64>) -> SearchState {
 | 
			
		||||
        let err_proj = gram.sub_proj(&(config.tr_mul(&*Q) * &config));
 | 
			
		||||
        let loss = err_proj.norm_squared();
 | 
			
		||||
        SearchState {
 | 
			
		||||
            config: config,
 | 
			
		||||
            err_proj: err_proj,
 | 
			
		||||
            loss: loss
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn basis_matrix(index: (usize, usize), nrows: usize, ncols: usize) -> DMatrix<f64> {
 | 
			
		||||
    let mut result = DMatrix::<f64>::zeros(nrows, ncols);
 | 
			
		||||
    result[index] = 1.0;
 | 
			
		||||
    result
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// use backtracking line search to find a better configuration
 | 
			
		||||
fn seek_better_config(
 | 
			
		||||
    gram: &PartialMatrix,
 | 
			
		||||
    state: &SearchState,
 | 
			
		||||
    base_step: &DMatrix<f64>,
 | 
			
		||||
    base_target_improvement: f64,
 | 
			
		||||
    min_efficiency: f64,
 | 
			
		||||
    backoff: f64,
 | 
			
		||||
    max_backoff_steps: i32
 | 
			
		||||
) -> Option<SearchState> {
 | 
			
		||||
    let mut rate = 1.0;
 | 
			
		||||
    for _ in 0..max_backoff_steps {
 | 
			
		||||
        let trial_config = &state.config + rate * base_step;
 | 
			
		||||
        let trial_state = SearchState::from_config(gram, trial_config);
 | 
			
		||||
        let improvement = state.loss - trial_state.loss;
 | 
			
		||||
        if improvement >= min_efficiency * rate * base_target_improvement {
 | 
			
		||||
            return Some(trial_state);
 | 
			
		||||
        }
 | 
			
		||||
        rate *= backoff;
 | 
			
		||||
    }
 | 
			
		||||
    None
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// seek a matrix `config` for which `config' * Q * config` matches the partial
 | 
			
		||||
// matrix `gram`. use gradient descent starting from `guess`
 | 
			
		||||
fn realize_gram(
 | 
			
		||||
    gram: &PartialMatrix,
 | 
			
		||||
    guess: DMatrix<f64>,
 | 
			
		||||
    frozen: &[(usize, usize)],
 | 
			
		||||
    scaled_tol: f64,
 | 
			
		||||
    min_efficiency: f64,
 | 
			
		||||
    backoff: f64,
 | 
			
		||||
    reg_scale: f64,
 | 
			
		||||
    max_descent_steps: i32,
 | 
			
		||||
    max_backoff_steps: i32
 | 
			
		||||
) -> (DMatrix<f64>, bool) {
 | 
			
		||||
    // find the dimension of the search space
 | 
			
		||||
    let element_dim = guess.nrows();
 | 
			
		||||
    let assembly_dim = guess.ncols();
 | 
			
		||||
    let total_dim = element_dim * assembly_dim;
 | 
			
		||||
    
 | 
			
		||||
    // scale the tolerance
 | 
			
		||||
    let scale_adjustment = ((guess.ncols() - frozen.len()) as f64).sqrt();
 | 
			
		||||
    let tol = scale_adjustment * scaled_tol;
 | 
			
		||||
    
 | 
			
		||||
    // convert the frozen indices to stacked format
 | 
			
		||||
    let frozen_stacked: Vec<usize> = frozen.into_iter().map(
 | 
			
		||||
        |index| index.1*element_dim + index.0
 | 
			
		||||
    ).collect();
 | 
			
		||||
    
 | 
			
		||||
    // use Newton's method with backtracking and gradient descent backup
 | 
			
		||||
    let mut state = SearchState::from_config(gram, guess);
 | 
			
		||||
    for _ in 0..max_descent_steps {
 | 
			
		||||
        // stop if the loss is tolerably low
 | 
			
		||||
        println!("loss: {}", state.loss);
 | 
			
		||||
        /*println!("projected error: {}", state.err_proj);*/
 | 
			
		||||
        if state.loss < tol { break; }
 | 
			
		||||
        
 | 
			
		||||
        // find the negative gradient of the loss function
 | 
			
		||||
        let neg_grad = 4.0 * &*Q * &state.config * &state.err_proj;
 | 
			
		||||
        let mut neg_grad_stacked = neg_grad.clone().reshape_generic(Dyn(total_dim), Const::<1>);
 | 
			
		||||
        
 | 
			
		||||
        // find the negative Hessian of the loss function
 | 
			
		||||
        let mut hess_cols = Vec::<DVector<f64>>::with_capacity(total_dim);
 | 
			
		||||
        for col in 0..assembly_dim {
 | 
			
		||||
            for row in 0..element_dim {
 | 
			
		||||
                let index = (row, col);
 | 
			
		||||
                let basis_mat = basis_matrix(index, element_dim, assembly_dim);
 | 
			
		||||
                let neg_d_err =
 | 
			
		||||
                    basis_mat.tr_mul(&*Q) * &state.config
 | 
			
		||||
                    + state.config.tr_mul(&*Q) * &basis_mat;
 | 
			
		||||
                let neg_d_err_proj = gram.proj(&neg_d_err);
 | 
			
		||||
                let deriv_grad = 4.0 * &*Q * (
 | 
			
		||||
                    -&basis_mat * &state.err_proj
 | 
			
		||||
                    + &state.config * &neg_d_err_proj
 | 
			
		||||
                );
 | 
			
		||||
                hess_cols.push(deriv_grad.reshape_generic(Dyn(total_dim), Const::<1>));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        let mut hess = DMatrix::from_columns(hess_cols.as_slice());
 | 
			
		||||
        
 | 
			
		||||
        // regularize the Hessian
 | 
			
		||||
        let min_eigval = hess.symmetric_eigenvalues().min();
 | 
			
		||||
        if min_eigval <= 0.0 {
 | 
			
		||||
            hess -= reg_scale * min_eigval * DMatrix::identity(total_dim, total_dim);
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        // project the negative gradient and negative Hessian onto the
 | 
			
		||||
        // orthogonal complement of the frozen subspace
 | 
			
		||||
        let zero_col = DVector::zeros(total_dim);
 | 
			
		||||
        let zero_row = zero_col.transpose();
 | 
			
		||||
        for &k in &frozen_stacked {
 | 
			
		||||
            neg_grad_stacked[k] = 0.0;
 | 
			
		||||
            hess.set_row(k, &zero_row);
 | 
			
		||||
            hess.set_column(k, &zero_col);
 | 
			
		||||
            hess[(k, k)] = 1.0;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        // compute the Newton step
 | 
			
		||||
        let base_step_stacked = hess.cholesky().unwrap().solve(&neg_grad_stacked);
 | 
			
		||||
        let base_step = base_step_stacked.reshape_generic(Dyn(element_dim), Dyn(assembly_dim));
 | 
			
		||||
        
 | 
			
		||||
        // use backtracking line search to find a better configuration
 | 
			
		||||
        match seek_better_config(
 | 
			
		||||
            gram, &state, &base_step, neg_grad.dot(&base_step),
 | 
			
		||||
            min_efficiency, backoff, max_backoff_steps
 | 
			
		||||
        ) {
 | 
			
		||||
            Some(better_state) => state = better_state,
 | 
			
		||||
            None => return (state.config, false)
 | 
			
		||||
        };
 | 
			
		||||
    }
 | 
			
		||||
    (state.config, state.loss < tol)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// --- tests ---
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod tests {
 | 
			
		||||
    use std::f64;
 | 
			
		||||
    
 | 
			
		||||
    use super::*;
 | 
			
		||||
    
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn sub_proj_test() {
 | 
			
		||||
        let target = PartialMatrix(vec![
 | 
			
		||||
            MatrixEntry { index: (0, 0), val: 19.0 },
 | 
			
		||||
            MatrixEntry { index: (0, 2), val: 39.0 },
 | 
			
		||||
            MatrixEntry { index: (1, 1), val: 59.0 },
 | 
			
		||||
            MatrixEntry { index: (1, 2), val: 69.0 }
 | 
			
		||||
        ]);
 | 
			
		||||
        let attempt = DMatrix::<f64>::from_row_slice(2, 3, &[
 | 
			
		||||
            1.0, 2.0, 3.0,
 | 
			
		||||
            4.0, 5.0, 6.0
 | 
			
		||||
        ]);
 | 
			
		||||
        let expected_result = DMatrix::<f64>::from_row_slice(2, 3, &[
 | 
			
		||||
            18.0, 0.0, 36.0,
 | 
			
		||||
            0.0, 54.0, 63.0
 | 
			
		||||
        ]);
 | 
			
		||||
        assert_eq!(target.sub_proj(&attempt), expected_result);
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn zero_loss_test() {
 | 
			
		||||
        let gram = PartialMatrix({
 | 
			
		||||
            let mut entries = Vec::<MatrixEntry>::new();
 | 
			
		||||
            for j in 0..3 {
 | 
			
		||||
                for k in 0..3 {
 | 
			
		||||
                    entries.push(MatrixEntry {
 | 
			
		||||
                        index: (j, k),
 | 
			
		||||
                        val: if j == k { 1.0 } else { -1.0 }
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            entries
 | 
			
		||||
        });
 | 
			
		||||
        let config = {
 | 
			
		||||
            let a: f64 = 0.75_f64.sqrt();
 | 
			
		||||
            DMatrix::from_columns(&[
 | 
			
		||||
                sphere(1.0, 0.0, 0.0, a),
 | 
			
		||||
                sphere(-0.5, a, 0.0, a),
 | 
			
		||||
                sphere(-0.5, -a, 0.0, a)
 | 
			
		||||
            ])
 | 
			
		||||
        };
 | 
			
		||||
        let state = SearchState::from_config(&gram, config);
 | 
			
		||||
        assert!(state.loss.abs() < f64::EPSILON);
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    // --- process inspection examples ---
 | 
			
		||||
    
 | 
			
		||||
    // these tests are meant for human inspection, not automated use. run them
 | 
			
		||||
    // one at a time in `--nocapture` mode and read through the results and
 | 
			
		||||
    // optimization histories that they print out. the `run-examples` script
 | 
			
		||||
    // will run all of them
 | 
			
		||||
    
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn three_spheres_example() {
 | 
			
		||||
        let gram = PartialMatrix({
 | 
			
		||||
            let mut entries = Vec::<MatrixEntry>::new();
 | 
			
		||||
            for j in 0..3 {
 | 
			
		||||
                for k in 0..3 {
 | 
			
		||||
                    entries.push(MatrixEntry {
 | 
			
		||||
                        index: (j, k),
 | 
			
		||||
                        val: if j == k { 1.0 } else { -1.0 }
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            entries
 | 
			
		||||
        });
 | 
			
		||||
        let guess = {
 | 
			
		||||
            let a: f64 = 0.75_f64.sqrt();
 | 
			
		||||
            DMatrix::from_columns(&[
 | 
			
		||||
                sphere(1.0, 0.0, 0.0, 1.0),
 | 
			
		||||
                sphere(-0.5, a, 0.0, 1.0),
 | 
			
		||||
                sphere(-0.5, -a, 0.0, 1.0)
 | 
			
		||||
            ])
 | 
			
		||||
        };
 | 
			
		||||
        println!();
 | 
			
		||||
        let (config, success) = realize_gram(
 | 
			
		||||
            &gram, guess, &[],
 | 
			
		||||
            1.0e-12, 0.5, 0.9, 1.1, 200, 110
 | 
			
		||||
        );
 | 
			
		||||
        print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
 | 
			
		||||
        let final_state = SearchState::from_config(&gram, config);
 | 
			
		||||
        if success {
 | 
			
		||||
            println!("Target accuracy achieved!");
 | 
			
		||||
        } else {
 | 
			
		||||
            println!("Failed to reach target accuracy");
 | 
			
		||||
        }
 | 
			
		||||
        println!("Loss: {}", final_state.loss);
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn point_on_sphere_example() {
 | 
			
		||||
        let gram = PartialMatrix({
 | 
			
		||||
            let mut entries = Vec::<MatrixEntry>::new();
 | 
			
		||||
            for j in 0..2 {
 | 
			
		||||
                for k in 0..2 {
 | 
			
		||||
                    entries.push(MatrixEntry {
 | 
			
		||||
                        index: (j, k),
 | 
			
		||||
                        val: if (j, k) == (1, 1) { 1.0 } else { 0.0 }
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            entries
 | 
			
		||||
        });
 | 
			
		||||
        let guess = DMatrix::from_columns(&[
 | 
			
		||||
            point(0.0, 0.0, 2.0),
 | 
			
		||||
            sphere(0.0, 0.0, 0.0, 1.0)
 | 
			
		||||
        ]);
 | 
			
		||||
        println!();
 | 
			
		||||
        let (config, success) = realize_gram(
 | 
			
		||||
            &gram, guess, &[(3, 0)],
 | 
			
		||||
            1.0e-12, 0.5, 0.9, 1.1, 200, 110
 | 
			
		||||
        );
 | 
			
		||||
        print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
 | 
			
		||||
        print!("Configuration:{}", config);
 | 
			
		||||
        let final_state = SearchState::from_config(&gram, config);
 | 
			
		||||
        if success {
 | 
			
		||||
            println!("Target accuracy achieved!");
 | 
			
		||||
        } else {
 | 
			
		||||
            println!("Failed to reach target accuracy");
 | 
			
		||||
        }
 | 
			
		||||
        println!("Loss: {}", final_state.loss);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue