Port the Gram matrix realization routine to Rust
Validate with the process inspection example tests, which print out their results and optimization histories when run one at a time in `--nocapture` mode.
This commit is contained in:
parent
e59d60bf77
commit
9fe03264ab
@ -10,6 +10,7 @@ default = ["console_error_panic_hook"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
itertools = "0.13.0"
|
itertools = "0.13.0"
|
||||||
js-sys = "0.3.70"
|
js-sys = "0.3.70"
|
||||||
|
lazy_static = "1.5.0"
|
||||||
nalgebra = "0.33.0"
|
nalgebra = "0.33.0"
|
||||||
rustc-hash = "2.0.0"
|
rustc-hash = "2.0.0"
|
||||||
slab = "0.4.9"
|
slab = "0.4.9"
|
||||||
|
7
app-proto/run-examples
Executable file
7
app-proto/run-examples
Executable file
@ -0,0 +1,7 @@
|
|||||||
|
# based on "Enabling print statements in Cargo tests", by Jon Almeida
|
||||||
|
#
|
||||||
|
# https://jonalmeida.com/posts/2015/01/23/print-cargo/
|
||||||
|
#
|
||||||
|
|
||||||
|
cargo test -- --nocapture engine::tests::three_spheres_example
|
||||||
|
cargo test -- --nocapture engine::tests::point_on_sphere_example
|
@ -1,4 +1,11 @@
|
|||||||
use nalgebra::DVector;
|
use lazy_static::lazy_static;
|
||||||
|
use nalgebra::{Const, DMatrix, DVector, Dyn};
|
||||||
|
|
||||||
|
// --- elements ---
|
||||||
|
|
||||||
|
pub fn point(x: f64, y: f64, z: f64) -> DVector<f64> {
|
||||||
|
DVector::from_column_slice(&[x, y, z, 0.5, 0.5*(x*x + y*y + z*z)])
|
||||||
|
}
|
||||||
|
|
||||||
// the sphere with the given center and radius, with inward-pointing normals
|
// the sphere with the given center and radius, with inward-pointing normals
|
||||||
pub fn sphere(center_x: f64, center_y: f64, center_z: f64, radius: f64) -> DVector<f64> {
|
pub fn sphere(center_x: f64, center_y: f64, center_z: f64, radius: f64) -> DVector<f64> {
|
||||||
@ -24,4 +31,316 @@ pub fn sphere_with_offset(dir_x: f64, dir_y: f64, dir_z: f64, off: f64, curv: f6
|
|||||||
0.5 * curv,
|
0.5 * curv,
|
||||||
off * (1.0 + 0.5 * off * curv)
|
off * (1.0 + 0.5 * off * curv)
|
||||||
])
|
])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- partial matrices ---
|
||||||
|
|
||||||
|
struct MatrixEntry {
|
||||||
|
index: (usize, usize),
|
||||||
|
val: f64
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PartialMatrix(Vec<MatrixEntry>);
|
||||||
|
|
||||||
|
impl PartialMatrix {
|
||||||
|
fn proj(&self, a: &DMatrix<f64>) -> DMatrix<f64> {
|
||||||
|
let mut result = DMatrix::<f64>::zeros(a.nrows(), a.ncols());
|
||||||
|
let PartialMatrix(entries) = self;
|
||||||
|
for ent in entries {
|
||||||
|
result[ent.index] = a[ent.index];
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sub_proj(&self, rhs: &DMatrix<f64>) -> DMatrix<f64> {
|
||||||
|
let mut result = DMatrix::<f64>::zeros(rhs.nrows(), rhs.ncols());
|
||||||
|
let PartialMatrix(entries) = self;
|
||||||
|
for ent in entries {
|
||||||
|
result[ent.index] = ent.val - rhs[ent.index];
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- gram matrix realization ---
|
||||||
|
|
||||||
|
// the Lorentz form
|
||||||
|
lazy_static! {
|
||||||
|
static ref Q: DMatrix<f64> = DMatrix::from_row_slice(5, 5, &[
|
||||||
|
1.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 1.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, -2.0,
|
||||||
|
0.0, 0.0, 0.0, -2.0, 0.0
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SearchState {
|
||||||
|
config: DMatrix<f64>,
|
||||||
|
err_proj: DMatrix<f64>,
|
||||||
|
loss: f64
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SearchState {
|
||||||
|
fn from_config(gram: &PartialMatrix, config: DMatrix<f64>) -> SearchState {
|
||||||
|
let err_proj = gram.sub_proj(&(config.tr_mul(&*Q) * &config));
|
||||||
|
let loss = err_proj.norm_squared();
|
||||||
|
SearchState {
|
||||||
|
config: config,
|
||||||
|
err_proj: err_proj,
|
||||||
|
loss: loss
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn basis_matrix(index: (usize, usize), nrows: usize, ncols: usize) -> DMatrix<f64> {
|
||||||
|
let mut result = DMatrix::<f64>::zeros(nrows, ncols);
|
||||||
|
result[index] = 1.0;
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
// use backtracking line search to find a better configuration
|
||||||
|
fn seek_better_config(
|
||||||
|
gram: &PartialMatrix,
|
||||||
|
state: &SearchState,
|
||||||
|
base_step: &DMatrix<f64>,
|
||||||
|
base_target_improvement: f64,
|
||||||
|
min_efficiency: f64,
|
||||||
|
backoff: f64,
|
||||||
|
max_backoff_steps: i32
|
||||||
|
) -> Option<SearchState> {
|
||||||
|
let mut rate = 1.0;
|
||||||
|
for _ in 0..max_backoff_steps {
|
||||||
|
let trial_config = &state.config + rate * base_step;
|
||||||
|
let trial_state = SearchState::from_config(gram, trial_config);
|
||||||
|
let improvement = state.loss - trial_state.loss;
|
||||||
|
if improvement >= min_efficiency * rate * base_target_improvement {
|
||||||
|
return Some(trial_state);
|
||||||
|
}
|
||||||
|
rate *= backoff;
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
// seek a matrix `config` for which `config' * Q * config` matches the partial
|
||||||
|
// matrix `gram`. use gradient descent starting from `guess`
|
||||||
|
fn realize_gram(
|
||||||
|
gram: &PartialMatrix,
|
||||||
|
guess: DMatrix<f64>,
|
||||||
|
frozen: &[(usize, usize)],
|
||||||
|
scaled_tol: f64,
|
||||||
|
min_efficiency: f64,
|
||||||
|
backoff: f64,
|
||||||
|
reg_scale: f64,
|
||||||
|
max_descent_steps: i32,
|
||||||
|
max_backoff_steps: i32
|
||||||
|
) -> (DMatrix<f64>, bool) {
|
||||||
|
// find the dimension of the search space
|
||||||
|
let element_dim = guess.nrows();
|
||||||
|
let assembly_dim = guess.ncols();
|
||||||
|
let total_dim = element_dim * assembly_dim;
|
||||||
|
|
||||||
|
// scale the tolerance
|
||||||
|
let scale_adjustment = ((guess.ncols() - frozen.len()) as f64).sqrt();
|
||||||
|
let tol = scale_adjustment * scaled_tol;
|
||||||
|
|
||||||
|
// convert the frozen indices to stacked format
|
||||||
|
let frozen_stacked: Vec<usize> = frozen.into_iter().map(
|
||||||
|
|index| index.1*element_dim + index.0
|
||||||
|
).collect();
|
||||||
|
|
||||||
|
// use Newton's method with backtracking and gradient descent backup
|
||||||
|
let mut state = SearchState::from_config(gram, guess);
|
||||||
|
for _ in 0..max_descent_steps {
|
||||||
|
// stop if the loss is tolerably low
|
||||||
|
println!("loss: {}", state.loss);
|
||||||
|
/*println!("projected error: {}", state.err_proj);*/
|
||||||
|
if state.loss < tol { break; }
|
||||||
|
|
||||||
|
// find the negative gradient of the loss function
|
||||||
|
let neg_grad = 4.0 * &*Q * &state.config * &state.err_proj;
|
||||||
|
let mut neg_grad_stacked = neg_grad.clone().reshape_generic(Dyn(total_dim), Const::<1>);
|
||||||
|
|
||||||
|
// find the negative Hessian of the loss function
|
||||||
|
let mut hess_cols = Vec::<DVector<f64>>::with_capacity(total_dim);
|
||||||
|
for col in 0..assembly_dim {
|
||||||
|
for row in 0..element_dim {
|
||||||
|
let index = (row, col);
|
||||||
|
let basis_mat = basis_matrix(index, element_dim, assembly_dim);
|
||||||
|
let neg_d_err =
|
||||||
|
basis_mat.tr_mul(&*Q) * &state.config
|
||||||
|
+ state.config.tr_mul(&*Q) * &basis_mat;
|
||||||
|
let neg_d_err_proj = gram.proj(&neg_d_err);
|
||||||
|
let deriv_grad = 4.0 * &*Q * (
|
||||||
|
-&basis_mat * &state.err_proj
|
||||||
|
+ &state.config * &neg_d_err_proj
|
||||||
|
);
|
||||||
|
hess_cols.push(deriv_grad.reshape_generic(Dyn(total_dim), Const::<1>));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut hess = DMatrix::from_columns(hess_cols.as_slice());
|
||||||
|
|
||||||
|
// regularize the Hessian
|
||||||
|
let min_eigval = hess.symmetric_eigenvalues().min();
|
||||||
|
if min_eigval <= 0.0 {
|
||||||
|
hess -= reg_scale * min_eigval * DMatrix::identity(total_dim, total_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
// project the negative gradient and negative Hessian onto the
|
||||||
|
// orthogonal complement of the frozen subspace
|
||||||
|
let zero_col = DVector::zeros(total_dim);
|
||||||
|
let zero_row = zero_col.transpose();
|
||||||
|
for &k in &frozen_stacked {
|
||||||
|
neg_grad_stacked[k] = 0.0;
|
||||||
|
hess.set_row(k, &zero_row);
|
||||||
|
hess.set_column(k, &zero_col);
|
||||||
|
hess[(k, k)] = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute the Newton step
|
||||||
|
let base_step_stacked = hess.cholesky().unwrap().solve(&neg_grad_stacked);
|
||||||
|
let base_step = base_step_stacked.reshape_generic(Dyn(element_dim), Dyn(assembly_dim));
|
||||||
|
|
||||||
|
// use backtracking line search to find a better configuration
|
||||||
|
match seek_better_config(
|
||||||
|
gram, &state, &base_step, neg_grad.dot(&base_step),
|
||||||
|
min_efficiency, backoff, max_backoff_steps
|
||||||
|
) {
|
||||||
|
Some(better_state) => state = better_state,
|
||||||
|
None => return (state.config, false)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
(state.config, state.loss < tol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- tests ---
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::f64;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sub_proj_test() {
|
||||||
|
let target = PartialMatrix(vec![
|
||||||
|
MatrixEntry { index: (0, 0), val: 19.0 },
|
||||||
|
MatrixEntry { index: (0, 2), val: 39.0 },
|
||||||
|
MatrixEntry { index: (1, 1), val: 59.0 },
|
||||||
|
MatrixEntry { index: (1, 2), val: 69.0 }
|
||||||
|
]);
|
||||||
|
let attempt = DMatrix::<f64>::from_row_slice(2, 3, &[
|
||||||
|
1.0, 2.0, 3.0,
|
||||||
|
4.0, 5.0, 6.0
|
||||||
|
]);
|
||||||
|
let expected_result = DMatrix::<f64>::from_row_slice(2, 3, &[
|
||||||
|
18.0, 0.0, 36.0,
|
||||||
|
0.0, 54.0, 63.0
|
||||||
|
]);
|
||||||
|
assert_eq!(target.sub_proj(&attempt), expected_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn zero_loss_test() {
|
||||||
|
let gram = PartialMatrix({
|
||||||
|
let mut entries = Vec::<MatrixEntry>::new();
|
||||||
|
for j in 0..3 {
|
||||||
|
for k in 0..3 {
|
||||||
|
entries.push(MatrixEntry {
|
||||||
|
index: (j, k),
|
||||||
|
val: if j == k { 1.0 } else { -1.0 }
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entries
|
||||||
|
});
|
||||||
|
let config = {
|
||||||
|
let a: f64 = 0.75_f64.sqrt();
|
||||||
|
DMatrix::from_columns(&[
|
||||||
|
sphere(1.0, 0.0, 0.0, a),
|
||||||
|
sphere(-0.5, a, 0.0, a),
|
||||||
|
sphere(-0.5, -a, 0.0, a)
|
||||||
|
])
|
||||||
|
};
|
||||||
|
let state = SearchState::from_config(&gram, config);
|
||||||
|
assert!(state.loss.abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- process inspection examples ---
|
||||||
|
|
||||||
|
// these tests are meant for human inspection, not automated use. run them
|
||||||
|
// one at a time in `--nocapture` mode and read through the results and
|
||||||
|
// optimization histories that they print out. the `run-examples` script
|
||||||
|
// will run all of them
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn three_spheres_example() {
|
||||||
|
let gram = PartialMatrix({
|
||||||
|
let mut entries = Vec::<MatrixEntry>::new();
|
||||||
|
for j in 0..3 {
|
||||||
|
for k in 0..3 {
|
||||||
|
entries.push(MatrixEntry {
|
||||||
|
index: (j, k),
|
||||||
|
val: if j == k { 1.0 } else { -1.0 }
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entries
|
||||||
|
});
|
||||||
|
let guess = {
|
||||||
|
let a: f64 = 0.75_f64.sqrt();
|
||||||
|
DMatrix::from_columns(&[
|
||||||
|
sphere(1.0, 0.0, 0.0, 1.0),
|
||||||
|
sphere(-0.5, a, 0.0, 1.0),
|
||||||
|
sphere(-0.5, -a, 0.0, 1.0)
|
||||||
|
])
|
||||||
|
};
|
||||||
|
println!();
|
||||||
|
let (config, success) = realize_gram(
|
||||||
|
&gram, guess, &[],
|
||||||
|
1.0e-12, 0.5, 0.9, 1.1, 200, 110
|
||||||
|
);
|
||||||
|
print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
|
||||||
|
let final_state = SearchState::from_config(&gram, config);
|
||||||
|
if success {
|
||||||
|
println!("Target accuracy achieved!");
|
||||||
|
} else {
|
||||||
|
println!("Failed to reach target accuracy");
|
||||||
|
}
|
||||||
|
println!("Loss: {}", final_state.loss);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn point_on_sphere_example() {
|
||||||
|
let gram = PartialMatrix({
|
||||||
|
let mut entries = Vec::<MatrixEntry>::new();
|
||||||
|
for j in 0..2 {
|
||||||
|
for k in 0..2 {
|
||||||
|
entries.push(MatrixEntry {
|
||||||
|
index: (j, k),
|
||||||
|
val: if (j, k) == (1, 1) { 1.0 } else { 0.0 }
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entries
|
||||||
|
});
|
||||||
|
let guess = DMatrix::from_columns(&[
|
||||||
|
point(0.0, 0.0, 2.0),
|
||||||
|
sphere(0.0, 0.0, 0.0, 1.0)
|
||||||
|
]);
|
||||||
|
println!();
|
||||||
|
let (config, success) = realize_gram(
|
||||||
|
&gram, guess, &[(3, 0)],
|
||||||
|
1.0e-12, 0.5, 0.9, 1.1, 200, 110
|
||||||
|
);
|
||||||
|
print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
|
||||||
|
print!("Configuration:{}", config);
|
||||||
|
let final_state = SearchState::from_config(&gram, config);
|
||||||
|
if success {
|
||||||
|
println!("Target accuracy achieved!");
|
||||||
|
} else {
|
||||||
|
println!("Failed to reach target accuracy");
|
||||||
|
}
|
||||||
|
println!("Loss: {}", final_state.loss);
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user