Introduce soft constraints

Use a penalty method as a quick & dirty way to get started.
This commit is contained in:
Aaron Fenyes 2025-09-18 10:31:44 -07:00
parent 9e74d4e837
commit 3664ea73b1
3 changed files with 105 additions and 59 deletions

View file

@ -402,6 +402,7 @@ pub struct InversiveDistanceRegulator {
pub subjects: [Rc<dyn Element>; 2], pub subjects: [Rc<dyn Element>; 2],
pub measurement: ReadSignal<f64>, pub measurement: ReadSignal<f64>,
pub set_point: Signal<SpecifiedValue>, pub set_point: Signal<SpecifiedValue>,
pub soft: Signal<bool>,
distortion: Option<ReadSignal<f64>>, /* KLUDGE */ distortion: Option<ReadSignal<f64>>, /* KLUDGE */
serial: u64, serial: u64,
} }
@ -432,9 +433,10 @@ impl InversiveDistanceRegulator {
} else { } else {
None None
}; };
let soft = create_signal(false);
let serial = Self::next_serial(); let serial = Self::next_serial();
Self { subjects, measurement, set_point, distortion, serial } Self { subjects, measurement, set_point, soft, distortion, serial }
} }
} }
@ -464,6 +466,7 @@ impl Serial for InversiveDistanceRegulator {
impl ProblemPoser for InversiveDistanceRegulator { impl ProblemPoser for InversiveDistanceRegulator {
fn pose(&self, problem: &mut ConstraintProblem) { fn pose(&self, problem: &mut ConstraintProblem) {
let soft = self.soft.get_untracked();
self.set_point.with_untracked(|set_pt| { self.set_point.with_untracked(|set_pt| {
if let Some(val) = set_pt.value { if let Some(val) = set_pt.value {
let [row, col] = self.subjects.each_ref().map( let [row, col] = self.subjects.each_ref().map(
@ -471,8 +474,12 @@ impl ProblemPoser for InversiveDistanceRegulator {
"Subjects should be indexed before inversive distance regulator writes problem data" "Subjects should be indexed before inversive distance regulator writes problem data"
) )
); );
if soft {
problem.soft.push_sym(row, col, val);
} else {
problem.gram.push_sym(row, col, val); problem.gram.push_sym(row, col, val);
} }
}
}); });
} }
} }

View file

@ -882,10 +882,11 @@ fn load_irisawa_hexlet(assembly: &Assembly) {
assembly.insert_regulator(Rc::new(outer_moon_tangency)); assembly.insert_regulator(Rc::new(outer_moon_tangency));
} }
fn regular_diagonals<'a, const N: usize>(vertex_ids: [&'a str; N]) -> Vec<(f64, Vec<[&'a str; 2]>)> { fn regular_diagonals<'a, const N: usize>(vertex_ids: [&'a str; N]) -> Vec<(bool, f64, Vec<[&'a str; 2]>)> {
let ang = PI / (N as f64); let ang = PI / (N as f64);
let ang_sin = ang.sin(); let ang_sin = ang.sin();
(2..N-1).map(|sep| ( (2..N-1).map(|sep| (
false,
(sep as f64 * ang).sin() / ang_sin, (sep as f64 * ang).sin() / ang_sin,
(0..N-sep).map(|k| [vertex_ids[k], vertex_ids[k + sep]]).collect() (0..N-sep).map(|k| [vertex_ids[k], vertex_ids[k + sep]]).collect()
)).collect() )).collect()
@ -1113,7 +1114,7 @@ fn load_554_base(assembly: &Assembly) {
"g_SSW", "g_SSE", "g_ESE", "g_ENE", "g_SSW", "g_SSE", "g_ESE", "g_ENE",
]; ];
let struts: Vec<_> = [ let struts: Vec<_> = [
(1.0, vec![ (false, 1.0, vec![
["a_NE", "a_NW"], ["a_NE", "a_NW"],
["a_NW", "a_SW"], ["a_NW", "a_SW"],
["a_SW", "a_SE"], ["a_SW", "a_SE"],
@ -1138,6 +1139,16 @@ fn load_554_base(assembly: &Assembly) {
["c_S", "d_SE"], ["c_S", "d_SE"],
["c_E", "d_SE"], ["c_E", "d_SE"],
["c_E", "d_NE"], ["c_E", "d_NE"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]),
(true, 1.0, vec![
["d_NE", "e_N"], ["d_NE", "e_N"],
["d_NW", "e_N"], ["d_NW", "e_N"],
["d_NW", "e_W"], ["d_NW", "e_W"],
@ -1178,14 +1189,6 @@ fn load_554_base(assembly: &Assembly) {
["e_S", "g_SSE"], ["e_S", "g_SSE"],
["e_E", "g_ESE"], ["e_E", "g_ESE"],
["e_E", "g_ENE"], ["e_E", "g_ENE"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]), ]),
].into_iter() ].into_iter()
.chain(regular_diagonals(f_a)) .chain(regular_diagonals(f_a))
@ -1199,7 +1202,7 @@ fn load_554_base(assembly: &Assembly) {
.chain(regular_diagonals(f_abc_e)) .chain(regular_diagonals(f_abc_e))
.chain(regular_diagonals(f_g)) .chain(regular_diagonals(f_g))
.collect(); .collect();
for (length, vertex_pairs) in struts { for (soft, length, vertex_pairs) in struts {
let inv_dist = Some(-0.5 * length * length); let inv_dist = Some(-0.5 * length * length);
for pair in vertex_pairs { for pair in vertex_pairs {
let adjacent_vertices = pair.map( let adjacent_vertices = pair.map(
@ -1209,6 +1212,7 @@ fn load_554_base(assembly: &Assembly) {
); );
let distance = InversiveDistanceRegulator::new(adjacent_vertices); let distance = InversiveDistanceRegulator::new(adjacent_vertices);
distance.set_point.set(SpecifiedValue::from(inv_dist)); distance.set_point.set(SpecifiedValue::from(inv_dist));
distance.soft.set(soft);
assembly.insert_regulator(Rc::new(distance)); assembly.insert_regulator(Rc::new(distance));
} }
} }
@ -1560,7 +1564,7 @@ fn load_554_aug1(assembly: &Assembly) {
"g_SSW", "g_SSE", "g_ESE", "g_ENE", "g_SSW", "g_SSE", "g_ESE", "g_ENE",
]; ];
let struts: Vec<_> = [ let struts: Vec<_> = [
(1.0, vec![ (false, 1.0, vec![
["a_NE", "a_NW"], ["a_NE", "a_NW"],
["a_NW", "a_SW"], ["a_NW", "a_SW"],
["a_SW", "a_SE"], ["a_SW", "a_SE"],
@ -1577,6 +1581,16 @@ fn load_554_aug1(assembly: &Assembly) {
["b_SE", "c_S"], ["b_SE", "c_S"],
["b_SE", "c_E"], ["b_SE", "c_E"],
["b_NE", "c_E"], ["b_NE", "c_E"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]),
(true, 1.0, vec![
["c_N", "d_NE"], ["c_N", "d_NE"],
["c_N", "d_NW"], ["c_N", "d_NW"],
["c_W", "d_NW"], ["c_W", "d_NW"],
@ -1641,14 +1655,6 @@ fn load_554_aug1(assembly: &Assembly) {
["e_S", "g_SSE"], ["e_S", "g_SSE"],
["e_E", "g_ESE"], ["e_E", "g_ESE"],
["e_E", "g_ENE"], ["e_E", "g_ENE"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]), ]),
].into_iter() ].into_iter()
.chain(regular_diagonals(f_a)) .chain(regular_diagonals(f_a))
@ -1658,7 +1664,7 @@ fn load_554_aug1(assembly: &Assembly) {
.chain(regular_diagonals(f_abc_e)) .chain(regular_diagonals(f_abc_e))
.chain(regular_diagonals(f_g)) .chain(regular_diagonals(f_g))
.collect(); .collect();
for (length, vertex_pairs) in struts { for (soft, length, vertex_pairs) in struts {
let inv_dist = Some(-0.5 * length * length); let inv_dist = Some(-0.5 * length * length);
for pair in vertex_pairs { for pair in vertex_pairs {
let adjacent_vertices = pair.map( let adjacent_vertices = pair.map(
@ -1668,6 +1674,7 @@ fn load_554_aug1(assembly: &Assembly) {
); );
let distance = InversiveDistanceRegulator::new(adjacent_vertices); let distance = InversiveDistanceRegulator::new(adjacent_vertices);
distance.set_point.set(SpecifiedValue::from(inv_dist)); distance.set_point.set(SpecifiedValue::from(inv_dist));
distance.soft.set(soft);
assembly.insert_regulator(Rc::new(distance)); assembly.insert_regulator(Rc::new(distance));
} }
} }
@ -1991,7 +1998,7 @@ fn load_554_aug1_inner(assembly: &Assembly) {
"g_SSW", "g_SSE", "g_ESE", "g_ENE", "g_SSW", "g_SSE", "g_ESE", "g_ENE",
]; ];
let struts: Vec<_> = [ let struts: Vec<_> = [
(1.0, vec![ (false, 1.0, vec![
["a_NE", "a_NW"], ["a_NE", "a_NW"],
["a_NW", "a_SW"], ["a_NW", "a_SW"],
["a_SW", "a_SE"], ["a_SW", "a_SE"],
@ -2008,6 +2015,16 @@ fn load_554_aug1_inner(assembly: &Assembly) {
["b_SE", "c_S"], ["b_SE", "c_S"],
["b_SE", "c_E"], ["b_SE", "c_E"],
["b_NE", "c_E"], ["b_NE", "c_E"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]),
(true, 1.0, vec![
["c_N", "d_NE"], ["c_N", "d_NE"],
["c_N", "d_NW"], ["c_N", "d_NW"],
["c_W", "d_NW"], ["c_W", "d_NW"],
@ -2072,14 +2089,6 @@ fn load_554_aug1_inner(assembly: &Assembly) {
["e_S", "g_SSE"], ["e_S", "g_SSE"],
["e_E", "g_ESE"], ["e_E", "g_ESE"],
["e_E", "g_ENE"], ["e_E", "g_ENE"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]), ]),
].into_iter() ].into_iter()
.chain(regular_diagonals(f_a)) .chain(regular_diagonals(f_a))
@ -2089,7 +2098,7 @@ fn load_554_aug1_inner(assembly: &Assembly) {
.chain(regular_diagonals(f_abc_e)) .chain(regular_diagonals(f_abc_e))
.chain(regular_diagonals(f_g)) .chain(regular_diagonals(f_g))
.collect(); .collect();
for (length, vertex_pairs) in struts { for (soft, length, vertex_pairs) in struts {
let inv_dist = Some(-0.5 * length * length); let inv_dist = Some(-0.5 * length * length);
for pair in vertex_pairs { for pair in vertex_pairs {
let adjacent_vertices = pair.map( let adjacent_vertices = pair.map(
@ -2099,6 +2108,7 @@ fn load_554_aug1_inner(assembly: &Assembly) {
); );
let distance = InversiveDistanceRegulator::new(adjacent_vertices); let distance = InversiveDistanceRegulator::new(adjacent_vertices);
distance.set_point.set(SpecifiedValue::from(inv_dist)); distance.set_point.set(SpecifiedValue::from(inv_dist));
distance.soft.set(soft);
assembly.insert_regulator(Rc::new(distance)); assembly.insert_regulator(Rc::new(distance));
} }
} }
@ -2433,7 +2443,7 @@ fn load_554_aug2(assembly: &Assembly) {
"g_SSW", "g_SSE", "g_ESE", "g_ENE", "g_SSW", "g_SSE", "g_ESE", "g_ENE",
]; ];
let struts: Vec<_> = [ let struts: Vec<_> = [
(1.0, vec![ (false, 1.0, vec![
["a_NE", "a_NW"], ["a_NE", "a_NW"],
["a_NW", "a_SW"], ["a_NW", "a_SW"],
["a_SW", "a_SE"], ["a_SW", "a_SE"],
@ -2441,11 +2451,21 @@ fn load_554_aug2(assembly: &Assembly) {
["a_NE", "b_NE"], ["a_NE", "b_NE"],
["a_NW", "b_NW"], ["a_NW", "b_NW"],
["a_SW", "b_SW"], ["a_SW", "b_SW"],
["a_SE", "b_SE"],
["b_NE", "c_N"], ["b_NE", "c_N"],
["b_NW", "c_N"], ["b_NW", "c_N"],
["b_NW", "c_W"], ["b_NW", "c_W"],
["b_SW", "c_W"], ["b_SW", "c_W"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]),
(true, 1.0, vec![
["a_SE", "b_SE"],
["b_SW", "c_S"], ["b_SW", "c_S"],
["b_SE", "c_S"], ["b_SE", "c_S"],
["b_SE", "c_E"], ["b_SE", "c_E"],
@ -2524,14 +2544,6 @@ fn load_554_aug2(assembly: &Assembly) {
["e_S", "g_SSE"], ["e_S", "g_SSE"],
["e_E", "g_ESE"], ["e_E", "g_ESE"],
["e_E", "g_ENE"], ["e_E", "g_ENE"],
["g_NNE", "g_NNW"],
["g_NNW", "g_WNW"],
["g_WNW", "g_WSW"],
["g_WSW", "g_SSW"],
["g_SSW", "g_SSE"],
["g_SSE", "g_ESE"],
["g_ESE", "g_ENE"],
["g_ENE", "g_NNE"],
]), ]),
].into_iter() ].into_iter()
.chain(regular_diagonals(f_a)) .chain(regular_diagonals(f_a))
@ -2539,7 +2551,7 @@ fn load_554_aug2(assembly: &Assembly) {
.chain(regular_diagonals(f_abc_w)) .chain(regular_diagonals(f_abc_w))
.chain(regular_diagonals(f_g)) .chain(regular_diagonals(f_g))
.collect(); .collect();
for (length, vertex_pairs) in struts { for (soft, length, vertex_pairs) in struts {
let inv_dist = Some(-0.5 * length * length); let inv_dist = Some(-0.5 * length * length);
for pair in vertex_pairs { for pair in vertex_pairs {
let adjacent_vertices = pair.map( let adjacent_vertices = pair.map(
@ -2549,6 +2561,7 @@ fn load_554_aug2(assembly: &Assembly) {
); );
let distance = InversiveDistanceRegulator::new(adjacent_vertices); let distance = InversiveDistanceRegulator::new(adjacent_vertices);
distance.set_point.set(SpecifiedValue::from(inv_dist)); distance.set_point.set(SpecifiedValue::from(inv_dist));
distance.soft.set(soft);
assembly.insert_regulator(Rc::new(distance)); assembly.insert_regulator(Rc::new(distance));
} }
} }
@ -2922,7 +2935,7 @@ fn load_554_domed(assembly: &Assembly) {
let f_abc_n = ["a_NW", "b_NW", "c_N", "b_NE", "a_NE"]; let f_abc_n = ["a_NW", "b_NW", "c_N", "b_NE", "a_NE"];
let f_abc_w = ["a_SW", "b_SW", "c_W", "b_NW", "a_NW"]; let f_abc_w = ["a_SW", "b_SW", "c_W", "b_NW", "a_NW"];
let struts: Vec<_> = [ let struts: Vec<_> = [
(1.0, vec![ (false, 1.0, vec![
["a_NE", "a_NW"], ["a_NE", "a_NW"],
["a_NW", "a_SW"], ["a_NW", "a_SW"],
["a_SW", "a_SE"], ["a_SW", "a_SE"],
@ -2930,11 +2943,13 @@ fn load_554_domed(assembly: &Assembly) {
["a_NE", "b_NE"], ["a_NE", "b_NE"],
["a_NW", "b_NW"], ["a_NW", "b_NW"],
["a_SW", "b_SW"], ["a_SW", "b_SW"],
["a_SE", "b_SE"],
["b_NE", "c_N"], ["b_NE", "c_N"],
["b_NW", "c_N"], ["b_NW", "c_N"],
["b_NW", "c_W"], ["b_NW", "c_W"],
["b_SW", "c_W"], ["b_SW", "c_W"],
]),
(true, 1.0, vec![
["a_SE", "b_SE"],
["b_SW", "c_S"], ["b_SW", "c_S"],
["b_SE", "c_S"], ["b_SE", "c_S"],
["b_SE", "c_E"], ["b_SE", "c_E"],
@ -3059,7 +3074,7 @@ fn load_554_domed(assembly: &Assembly) {
.chain(regular_diagonals(f_abc_n)) .chain(regular_diagonals(f_abc_n))
.chain(regular_diagonals(f_abc_w)) .chain(regular_diagonals(f_abc_w))
.collect(); .collect();
for (length, vertex_pairs) in struts { for (soft, length, vertex_pairs) in struts {
let inv_dist = Some(-0.5 * length * length); let inv_dist = Some(-0.5 * length * length);
for pair in vertex_pairs { for pair in vertex_pairs {
let adjacent_vertices = pair.map( let adjacent_vertices = pair.map(
@ -3069,6 +3084,7 @@ fn load_554_domed(assembly: &Assembly) {
); );
let distance = InversiveDistanceRegulator::new(adjacent_vertices); let distance = InversiveDistanceRegulator::new(adjacent_vertices);
distance.set_point.set(SpecifiedValue::from(inv_dist)); distance.set_point.set(SpecifiedValue::from(inv_dist));
distance.soft.set(soft);
assembly.insert_regulator(Rc::new(distance)); assembly.insert_regulator(Rc::new(distance));
} }
} }

View file

@ -1,6 +1,7 @@
use lazy_static::lazy_static; use lazy_static::lazy_static;
use nalgebra::{Const, DMatrix, DVector, DVectorView, Dyn, SymmetricEigen}; use nalgebra::{Const, DMatrix, DVector, DVectorView, Dyn, SymmetricEigen};
use std::fmt::{Display, Error, Formatter}; use std::fmt::{Display, Error, Formatter};
use sycamore::prelude::console_log; /* DEBUG */
// --- elements --- // --- elements ---
@ -240,6 +241,7 @@ impl DescentHistory {
pub struct ConstraintProblem { pub struct ConstraintProblem {
pub gram: PartialMatrix, pub gram: PartialMatrix,
pub soft: PartialMatrix,
pub frozen: PartialMatrix, pub frozen: PartialMatrix,
pub guess: DMatrix<f64>, pub guess: DMatrix<f64>,
} }
@ -249,6 +251,7 @@ impl ConstraintProblem {
const ELEMENT_DIM: usize = 5; const ELEMENT_DIM: usize = 5;
Self { Self {
gram: PartialMatrix::new(), gram: PartialMatrix::new(),
soft: PartialMatrix::new(),
frozen: PartialMatrix::new(), frozen: PartialMatrix::new(),
guess: DMatrix::<f64>::zeros(ELEMENT_DIM, element_count), guess: DMatrix::<f64>::zeros(ELEMENT_DIM, element_count),
} }
@ -258,6 +261,7 @@ impl ConstraintProblem {
pub fn from_guess(guess_columns: &[DVector<f64>]) -> Self { pub fn from_guess(guess_columns: &[DVector<f64>]) -> Self {
Self { Self {
gram: PartialMatrix::new(), gram: PartialMatrix::new(),
soft: PartialMatrix::new(),
frozen: PartialMatrix::new(), frozen: PartialMatrix::new(),
guess: DMatrix::from_columns(guess_columns), guess: DMatrix::from_columns(guess_columns),
} }
@ -280,14 +284,18 @@ lazy_static! {
struct SearchState { struct SearchState {
config: DMatrix<f64>, config: DMatrix<f64>,
err_proj: DMatrix<f64>, err_proj: DMatrix<f64>,
loss_hard: f64,
loss: f64, loss: f64,
} }
impl SearchState { impl SearchState {
fn from_config(gram: &PartialMatrix, config: DMatrix<f64>) -> Self { fn from_config(gram: &PartialMatrix, soft: &PartialMatrix, softness: f64, config: DMatrix<f64>) -> Self {
let err_proj = gram.sub_proj(&(config.tr_mul(&*Q) * &config)); let config_gram = &(config.tr_mul(&*Q) * &config);
let err_proj_hard = gram.sub_proj(config_gram);
let err_proj = &err_proj_hard + softness * soft.sub_proj(config_gram);
let loss_hard = err_proj_hard.norm_squared();
let loss = err_proj.norm_squared(); let loss = err_proj.norm_squared();
Self { config, err_proj, loss } Self { config, err_proj, loss_hard, loss }
} }
} }
@ -331,6 +339,8 @@ pub fn local_unif_to_std(v: DVectorView<f64>) -> DMatrix<f64> {
// use backtracking line search to find a better configuration // use backtracking line search to find a better configuration
fn seek_better_config( fn seek_better_config(
gram: &PartialMatrix, gram: &PartialMatrix,
soft: &PartialMatrix,
softness: f64,
state: &SearchState, state: &SearchState,
base_step: &DMatrix<f64>, base_step: &DMatrix<f64>,
base_target_improvement: f64, base_target_improvement: f64,
@ -341,7 +351,7 @@ fn seek_better_config(
let mut rate = 1.0; let mut rate = 1.0;
for backoff_steps in 0..max_backoff_steps { for backoff_steps in 0..max_backoff_steps {
let trial_config = &state.config + rate * base_step; let trial_config = &state.config + rate * base_step;
let trial_state = SearchState::from_config(gram, trial_config); let trial_state = SearchState::from_config(gram, soft, softness, trial_config);
let improvement = state.loss - trial_state.loss; let improvement = state.loss - trial_state.loss;
if improvement >= min_efficiency * rate * base_target_improvement { if improvement >= min_efficiency * rate * base_target_improvement {
return Some((trial_state, backoff_steps)); return Some((trial_state, backoff_steps));
@ -376,7 +386,7 @@ pub fn realize_gram(
max_backoff_steps: i32, max_backoff_steps: i32,
) -> Realization { ) -> Realization {
// destructure the problem data // destructure the problem data
let ConstraintProblem { gram, guess, frozen } = problem; let ConstraintProblem { gram, soft, guess, frozen } = problem;
// start the descent history // start the descent history
let mut history = DescentHistory::new(); let mut history = DescentHistory::new();
@ -403,13 +413,16 @@ pub fn realize_gram(
let scale_adjustment = (gram.0.len() as f64).sqrt(); let scale_adjustment = (gram.0.len() as f64).sqrt();
let tol = scale_adjustment * scaled_tol; let tol = scale_adjustment * scaled_tol;
// initialize the softness parameter
let mut softness = 1.0;
// convert the frozen indices to stacked format // convert the frozen indices to stacked format
let frozen_stacked: Vec<usize> = frozen.into_iter().map( let frozen_stacked: Vec<usize> = frozen.into_iter().map(
|MatrixEntry { index: (row, col), .. }| col*element_dim + row |MatrixEntry { index: (row, col), .. }| col*element_dim + row
).collect(); ).collect();
// use a regularized Newton's method with backtracking // use a regularized Newton's method with backtracking
let mut state = SearchState::from_config(gram, frozen.freeze(guess)); let mut state = SearchState::from_config(gram, soft, softness, frozen.freeze(guess));
let mut hess = DMatrix::zeros(element_dim, assembly_dim); let mut hess = DMatrix::zeros(element_dim, assembly_dim);
for _ in 0..max_descent_steps { for _ in 0..max_descent_steps {
// find the negative gradient of the loss function // find the negative gradient of the loss function
@ -426,7 +439,7 @@ pub fn realize_gram(
let neg_d_err = let neg_d_err =
basis_mat.tr_mul(&*Q) * &state.config basis_mat.tr_mul(&*Q) * &state.config
+ state.config.tr_mul(&*Q) * &basis_mat; + state.config.tr_mul(&*Q) * &basis_mat;
let neg_d_err_proj = gram.proj(&neg_d_err); let neg_d_err_proj = gram.proj(&neg_d_err) + softness * soft.proj(&neg_d_err);
let deriv_grad = 4.0 * &*Q * ( let deriv_grad = 4.0 * &*Q * (
-&basis_mat * &state.err_proj -&basis_mat * &state.err_proj
+ &state.config * &neg_d_err_proj + &state.config * &neg_d_err_proj
@ -455,10 +468,10 @@ pub fn realize_gram(
hess[(k, k)] = 1.0; hess[(k, k)] = 1.0;
} }
// stop if the loss is tolerably low // stop if the hard loss is tolerably low
history.config.push(state.config.clone()); history.config.push(state.config.clone());
history.scaled_loss.push(state.loss / scale_adjustment); history.scaled_loss.push(state.loss_hard / scale_adjustment);
if state.loss < tol { break; } if state.loss_hard < tol { break; }
// compute the Newton step // compute the Newton step
/* TO DO */ /* TO DO */
@ -482,7 +495,7 @@ pub fn realize_gram(
// use backtracking line search to find a better configuration // use backtracking line search to find a better configuration
if let Some((better_state, backoff_steps)) = seek_better_config( if let Some((better_state, backoff_steps)) = seek_better_config(
gram, &state, &base_step, neg_grad.dot(&base_step), gram, soft, softness, &state, &base_step, neg_grad.dot(&base_step),
min_efficiency, backoff, max_backoff_steps, min_efficiency, backoff, max_backoff_steps,
) { ) {
state = better_state; state = better_state;
@ -493,8 +506,18 @@ pub fn realize_gram(
history, history,
}; };
} }
// if we're near a minimum of the total loss, but the hard loss still
// isn't tolerably low, make the soft constraints softer
const GRAD_TOL: f64 = 1e-4;
const SOFTNESS_BACKOFF: f64 = 0.5;
if neg_grad.norm_squared() < GRAD_TOL {
// if we're close to a minimum, make the soft constraints softer
softness *= SOFTNESS_BACKOFF;
console_log!("Softness decreased to {softness}");
} }
let result = if state.loss < tol { }
let result = if state.loss_hard < tol {
// express the uniform basis in the standard basis // express the uniform basis in the standard basis
const UNIFORM_DIM: usize = 4; const UNIFORM_DIM: usize = 4;
let total_dim_unif = UNIFORM_DIM * assembly_dim; let total_dim_unif = UNIFORM_DIM * assembly_dim;