Compare commits
10 commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
cc2da3406b | ||
![]() |
b74cbf10c1 | ||
![]() |
bc17d71f4a | ||
![]() |
a203f6bc1b | ||
![]() |
3664ea73b1 | ||
![]() |
9e74d4e837 | ||
![]() |
0de32f5e11 | ||
![]() |
48a640605a | ||
![]() |
8bedb0baf7 | ||
![]() |
8a0d81d707 |
6 changed files with 2454 additions and 20 deletions
|
@ -227,6 +227,16 @@ details[open]:has(li) .element-switch::after {
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#distortion-bar {
|
||||||
|
display: flex;
|
||||||
|
margin-top: 8px;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#distortion-gauge {
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
/* display */
|
/* display */
|
||||||
|
|
||||||
#display {
|
#display {
|
||||||
|
|
|
@ -3,6 +3,7 @@ use std::{
|
||||||
cell::Cell,
|
cell::Cell,
|
||||||
cmp::Ordering,
|
cmp::Ordering,
|
||||||
collections::{BTreeMap, BTreeSet},
|
collections::{BTreeMap, BTreeSet},
|
||||||
|
f64::consts::SQRT_2,
|
||||||
fmt,
|
fmt,
|
||||||
fmt::{Debug, Formatter},
|
fmt::{Debug, Formatter},
|
||||||
hash::{Hash, Hasher},
|
hash::{Hash, Hasher},
|
||||||
|
@ -122,6 +123,11 @@ pub trait Element: Serial + ProblemPoser + DisplayItem {
|
||||||
// be used carefully to preserve invariant (1), described in the comment on
|
// be used carefully to preserve invariant (1), described in the comment on
|
||||||
// the `tangent` field of the `Assembly` structure
|
// the `tangent` field of the `Assembly` structure
|
||||||
fn set_column_index(&self, index: usize);
|
fn set_column_index(&self, index: usize);
|
||||||
|
|
||||||
|
/* KLUDGE */
|
||||||
|
fn has_distortion(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for dyn Element {
|
impl Debug for dyn Element {
|
||||||
|
@ -334,6 +340,10 @@ impl Element for Point {
|
||||||
fn set_column_index(&self, index: usize) {
|
fn set_column_index(&self, index: usize) {
|
||||||
self.column_index.set(Some(index));
|
self.column_index.set(Some(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn has_distortion(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Serial for Point {
|
impl Serial for Point {
|
||||||
|
@ -357,6 +367,12 @@ pub trait Regulator: Serial + ProblemPoser + OutlineItem {
|
||||||
fn subjects(&self) -> Vec<Rc<dyn Element>>;
|
fn subjects(&self) -> Vec<Rc<dyn Element>>;
|
||||||
fn measurement(&self) -> ReadSignal<f64>;
|
fn measurement(&self) -> ReadSignal<f64>;
|
||||||
fn set_point(&self) -> Signal<SpecifiedValue>;
|
fn set_point(&self) -> Signal<SpecifiedValue>;
|
||||||
|
fn soft(&self) -> Option<Signal<bool>> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
fn distortion(&self) -> Option<ReadSignal<f64>> { /* KLUDGE */
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Hash for dyn Regulator {
|
impl Hash for dyn Regulator {
|
||||||
|
@ -389,6 +405,8 @@ pub struct InversiveDistanceRegulator {
|
||||||
pub subjects: [Rc<dyn Element>; 2],
|
pub subjects: [Rc<dyn Element>; 2],
|
||||||
pub measurement: ReadSignal<f64>,
|
pub measurement: ReadSignal<f64>,
|
||||||
pub set_point: Signal<SpecifiedValue>,
|
pub set_point: Signal<SpecifiedValue>,
|
||||||
|
pub soft: Signal<bool>,
|
||||||
|
distortion: Option<ReadSignal<f64>>, /* KLUDGE */
|
||||||
serial: u64,
|
serial: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -404,9 +422,24 @@ impl InversiveDistanceRegulator {
|
||||||
});
|
});
|
||||||
|
|
||||||
let set_point = create_signal(SpecifiedValue::from_empty_spec());
|
let set_point = create_signal(SpecifiedValue::from_empty_spec());
|
||||||
|
let distortion = if subjects.iter().all(|subj| subj.has_distortion()) {
|
||||||
|
Some(create_memo(move || {
|
||||||
|
let set_point_opt = set_point.with(|set_pt| set_pt.value);
|
||||||
|
let measurement_val = measurement.get();
|
||||||
|
match set_point_opt {
|
||||||
|
None => 0.0,
|
||||||
|
Some(set_point_val) => SQRT_2 * (
|
||||||
|
(-measurement_val).sqrt() - (-set_point_val).sqrt()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let soft = create_signal(false);
|
||||||
let serial = Self::next_serial();
|
let serial = Self::next_serial();
|
||||||
|
|
||||||
Self { subjects, measurement, set_point, serial }
|
Self { subjects, measurement, set_point, soft, distortion, serial }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -422,6 +455,14 @@ impl Regulator for InversiveDistanceRegulator {
|
||||||
fn set_point(&self) -> Signal<SpecifiedValue> {
|
fn set_point(&self) -> Signal<SpecifiedValue> {
|
||||||
self.set_point
|
self.set_point
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn soft(&self) -> Option<Signal<bool>> {
|
||||||
|
Some(self.soft)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn distortion(&self) -> Option<ReadSignal<f64>> {
|
||||||
|
self.distortion
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Serial for InversiveDistanceRegulator {
|
impl Serial for InversiveDistanceRegulator {
|
||||||
|
@ -432,6 +473,7 @@ impl Serial for InversiveDistanceRegulator {
|
||||||
|
|
||||||
impl ProblemPoser for InversiveDistanceRegulator {
|
impl ProblemPoser for InversiveDistanceRegulator {
|
||||||
fn pose(&self, problem: &mut ConstraintProblem) {
|
fn pose(&self, problem: &mut ConstraintProblem) {
|
||||||
|
let soft = self.soft.get_untracked();
|
||||||
self.set_point.with_untracked(|set_pt| {
|
self.set_point.with_untracked(|set_pt| {
|
||||||
if let Some(val) = set_pt.value {
|
if let Some(val) = set_pt.value {
|
||||||
let [row, col] = self.subjects.each_ref().map(
|
let [row, col] = self.subjects.each_ref().map(
|
||||||
|
@ -439,8 +481,12 @@ impl ProblemPoser for InversiveDistanceRegulator {
|
||||||
"Subjects should be indexed before inversive distance regulator writes problem data"
|
"Subjects should be indexed before inversive distance regulator writes problem data"
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
if soft {
|
||||||
|
problem.soft.push_sym(row, col, val);
|
||||||
|
} else {
|
||||||
problem.gram.push_sym(row, col, val);
|
problem.gram.push_sym(row, col, val);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -705,7 +751,7 @@ impl Assembly {
|
||||||
|
|
||||||
// look for a configuration with the given Gram matrix
|
// look for a configuration with the given Gram matrix
|
||||||
let Realization { result, history } = realize_gram(
|
let Realization { result, history } = realize_gram(
|
||||||
&problem, 1.0e-12, 0.5, 0.9, 1.1, 200, 110
|
&problem, 1.0e-20, 0.5, 0.9, 1.1, 400, 110
|
||||||
);
|
);
|
||||||
|
|
||||||
/* DEBUG */
|
/* DEBUG */
|
||||||
|
|
|
@ -111,6 +111,80 @@ fn StepInput() -> View {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[component]
|
||||||
|
fn DistortionGauge() -> View {
|
||||||
|
let state = use_context::<AppState>();
|
||||||
|
let total_distortion = create_memo(move || {
|
||||||
|
state.assembly.regulators.with(|regs| {
|
||||||
|
let mut total = 0.0;
|
||||||
|
for reg in regs {
|
||||||
|
if let Some(distortion) = reg.distortion() {
|
||||||
|
total += distortion.get().abs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
total
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
view! {
|
||||||
|
div(id = "distortion-gauge") {
|
||||||
|
"Distortion: " (total_distortion.with(|distort| distort.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[component]
|
||||||
|
fn DistortionPrintButton() -> View {
|
||||||
|
view! {
|
||||||
|
button(
|
||||||
|
on:click = |_| {
|
||||||
|
let state = use_context::<AppState>();
|
||||||
|
let mut hard_distortion_table = String::new();
|
||||||
|
let mut soft_distortion_table = String::new();
|
||||||
|
let mut highest_distortion = f64::NEG_INFINITY;
|
||||||
|
let mut lowest_distortion = f64::INFINITY;
|
||||||
|
let mut largest_hard_distortion = f64::NEG_INFINITY;
|
||||||
|
state.assembly.regulators.with_untracked(|regs| {
|
||||||
|
for reg in regs {
|
||||||
|
if let Some(distortion) = reg.distortion() {
|
||||||
|
let distortion_val = distortion.get();
|
||||||
|
let subjects = reg.subjects();
|
||||||
|
let distortion_line = format!(
|
||||||
|
"{}, {}: {distortion_val}\n",
|
||||||
|
subjects[0].id(),
|
||||||
|
subjects[1].id(),
|
||||||
|
);
|
||||||
|
match reg.soft() {
|
||||||
|
Some(soft) if soft.get() => {
|
||||||
|
soft_distortion_table += &distortion_line;
|
||||||
|
highest_distortion = highest_distortion.max(distortion_val);
|
||||||
|
lowest_distortion = lowest_distortion.min(distortion_val);
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
hard_distortion_table += &distortion_line;
|
||||||
|
largest_hard_distortion = largest_hard_distortion.max(distortion_val.abs());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
console_log!("\
|
||||||
|
=== Distortions of flexible edges (for labels) ===\n\n\
|
||||||
|
--- Range ---\n\n\
|
||||||
|
Highest: {highest_distortion}\n\
|
||||||
|
Lowest: {lowest_distortion}\n\n\
|
||||||
|
--- Table ---\n\n{soft_distortion_table}\n\
|
||||||
|
=== Distortions of rigid edges (for validation) ===\n\n\
|
||||||
|
These values should be small relative to the ones for the flexible edges\n\n\
|
||||||
|
--- Range ---\n\n\
|
||||||
|
Largest absolute: {largest_hard_distortion}\n\n\
|
||||||
|
--- Table ---\n\n{hard_distortion_table}\
|
||||||
|
");
|
||||||
|
},
|
||||||
|
) { "Print" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn into_log10_time_point((step, value): (usize, f64)) -> Vec<Option<f64>> {
|
fn into_log10_time_point((step, value): (usize, f64)) -> Vec<Option<f64>> {
|
||||||
vec![
|
vec![
|
||||||
Some(step as f64),
|
Some(step as f64),
|
||||||
|
@ -315,6 +389,10 @@ pub fn Diagnostics() -> View {
|
||||||
}
|
}
|
||||||
DiagnosticsPanel(name = "loss") { LossHistory {} }
|
DiagnosticsPanel(name = "loss") { LossHistory {} }
|
||||||
DiagnosticsPanel(name = "spectrum") { SpectrumHistory {} }
|
DiagnosticsPanel(name = "spectrum") { SpectrumHistory {} }
|
||||||
|
div(id = "distortion-bar") {
|
||||||
|
DistortionGauge {}
|
||||||
|
DistortionPrintButton {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -597,16 +597,16 @@ pub fn Display() -> View {
|
||||||
|status| status.is_ok()
|
|status| status.is_ok()
|
||||||
);
|
);
|
||||||
let step_val = state.assembly.step.with_untracked(|step| step.value);
|
let step_val = state.assembly.step.with_untracked(|step| step.value);
|
||||||
let on_init_step = step_val.is_some_and(|n| n == 0.0);
|
|
||||||
let on_last_step = step_val.is_some_and(
|
let on_last_step = step_val.is_some_and(
|
||||||
|n| state.assembly.descent_history.with_untracked(
|
|n| state.assembly.descent_history.with_untracked(
|
||||||
|history| n as usize + 1 == history.config.len().max(1)
|
|history| n as usize + 1 == history.config.len().max(1)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
let on_manipulable_step =
|
if
|
||||||
!realization_successful && on_init_step
|
state.selection.with(|sel| sel.len() == 1)
|
||||||
|| realization_successful && on_last_step;
|
&& realization_successful
|
||||||
if on_manipulable_step && state.selection.with(|sel| sel.len() == 1) {
|
&& on_last_step
|
||||||
|
{
|
||||||
let sel = state.selection.with(
|
let sel = state.selection.with(
|
||||||
|sel| sel.into_iter().next().unwrap().clone()
|
|sel| sel.into_iter().next().unwrap().clone()
|
||||||
);
|
);
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,7 @@
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use nalgebra::{Const, DMatrix, DVector, DVectorView, Dyn, SymmetricEigen};
|
use nalgebra::{Const, DMatrix, DVector, DVectorView, Dyn, SymmetricEigen};
|
||||||
use std::fmt::{Display, Error, Formatter};
|
use std::fmt::{Display, Error, Formatter};
|
||||||
|
use sycamore::prelude::console_log; /* DEBUG */
|
||||||
|
|
||||||
// --- elements ---
|
// --- elements ---
|
||||||
|
|
||||||
|
@ -240,6 +241,7 @@ impl DescentHistory {
|
||||||
|
|
||||||
pub struct ConstraintProblem {
|
pub struct ConstraintProblem {
|
||||||
pub gram: PartialMatrix,
|
pub gram: PartialMatrix,
|
||||||
|
pub soft: PartialMatrix,
|
||||||
pub frozen: PartialMatrix,
|
pub frozen: PartialMatrix,
|
||||||
pub guess: DMatrix<f64>,
|
pub guess: DMatrix<f64>,
|
||||||
}
|
}
|
||||||
|
@ -249,6 +251,7 @@ impl ConstraintProblem {
|
||||||
const ELEMENT_DIM: usize = 5;
|
const ELEMENT_DIM: usize = 5;
|
||||||
Self {
|
Self {
|
||||||
gram: PartialMatrix::new(),
|
gram: PartialMatrix::new(),
|
||||||
|
soft: PartialMatrix::new(),
|
||||||
frozen: PartialMatrix::new(),
|
frozen: PartialMatrix::new(),
|
||||||
guess: DMatrix::<f64>::zeros(ELEMENT_DIM, element_count),
|
guess: DMatrix::<f64>::zeros(ELEMENT_DIM, element_count),
|
||||||
}
|
}
|
||||||
|
@ -258,6 +261,7 @@ impl ConstraintProblem {
|
||||||
pub fn from_guess(guess_columns: &[DVector<f64>]) -> Self {
|
pub fn from_guess(guess_columns: &[DVector<f64>]) -> Self {
|
||||||
Self {
|
Self {
|
||||||
gram: PartialMatrix::new(),
|
gram: PartialMatrix::new(),
|
||||||
|
soft: PartialMatrix::new(),
|
||||||
frozen: PartialMatrix::new(),
|
frozen: PartialMatrix::new(),
|
||||||
guess: DMatrix::from_columns(guess_columns),
|
guess: DMatrix::from_columns(guess_columns),
|
||||||
}
|
}
|
||||||
|
@ -280,14 +284,18 @@ lazy_static! {
|
||||||
struct SearchState {
|
struct SearchState {
|
||||||
config: DMatrix<f64>,
|
config: DMatrix<f64>,
|
||||||
err_proj: DMatrix<f64>,
|
err_proj: DMatrix<f64>,
|
||||||
|
loss_hard: f64,
|
||||||
loss: f64,
|
loss: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SearchState {
|
impl SearchState {
|
||||||
fn from_config(gram: &PartialMatrix, config: DMatrix<f64>) -> Self {
|
fn from_config(gram: &PartialMatrix, soft: &PartialMatrix, softness: f64, config: DMatrix<f64>) -> Self {
|
||||||
let err_proj = gram.sub_proj(&(config.tr_mul(&*Q) * &config));
|
let config_gram = &(config.tr_mul(&*Q) * &config);
|
||||||
|
let err_proj_hard = gram.sub_proj(config_gram);
|
||||||
|
let err_proj = &err_proj_hard + softness * soft.sub_proj(config_gram);
|
||||||
|
let loss_hard = err_proj_hard.norm_squared();
|
||||||
let loss = err_proj.norm_squared();
|
let loss = err_proj.norm_squared();
|
||||||
Self { config, err_proj, loss }
|
Self { config, err_proj, loss_hard, loss }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -331,6 +339,8 @@ pub fn local_unif_to_std(v: DVectorView<f64>) -> DMatrix<f64> {
|
||||||
// use backtracking line search to find a better configuration
|
// use backtracking line search to find a better configuration
|
||||||
fn seek_better_config(
|
fn seek_better_config(
|
||||||
gram: &PartialMatrix,
|
gram: &PartialMatrix,
|
||||||
|
soft: &PartialMatrix,
|
||||||
|
softness: f64,
|
||||||
state: &SearchState,
|
state: &SearchState,
|
||||||
base_step: &DMatrix<f64>,
|
base_step: &DMatrix<f64>,
|
||||||
base_target_improvement: f64,
|
base_target_improvement: f64,
|
||||||
|
@ -341,7 +351,7 @@ fn seek_better_config(
|
||||||
let mut rate = 1.0;
|
let mut rate = 1.0;
|
||||||
for backoff_steps in 0..max_backoff_steps {
|
for backoff_steps in 0..max_backoff_steps {
|
||||||
let trial_config = &state.config + rate * base_step;
|
let trial_config = &state.config + rate * base_step;
|
||||||
let trial_state = SearchState::from_config(gram, trial_config);
|
let trial_state = SearchState::from_config(gram, soft, softness, trial_config);
|
||||||
let improvement = state.loss - trial_state.loss;
|
let improvement = state.loss - trial_state.loss;
|
||||||
if improvement >= min_efficiency * rate * base_target_improvement {
|
if improvement >= min_efficiency * rate * base_target_improvement {
|
||||||
return Some((trial_state, backoff_steps));
|
return Some((trial_state, backoff_steps));
|
||||||
|
@ -376,7 +386,7 @@ pub fn realize_gram(
|
||||||
max_backoff_steps: i32,
|
max_backoff_steps: i32,
|
||||||
) -> Realization {
|
) -> Realization {
|
||||||
// destructure the problem data
|
// destructure the problem data
|
||||||
let ConstraintProblem { gram, guess, frozen } = problem;
|
let ConstraintProblem { gram, soft, guess, frozen } = problem;
|
||||||
|
|
||||||
// start the descent history
|
// start the descent history
|
||||||
let mut history = DescentHistory::new();
|
let mut history = DescentHistory::new();
|
||||||
|
@ -403,13 +413,18 @@ pub fn realize_gram(
|
||||||
let scale_adjustment = (gram.0.len() as f64).sqrt();
|
let scale_adjustment = (gram.0.len() as f64).sqrt();
|
||||||
let tol = scale_adjustment * scaled_tol;
|
let tol = scale_adjustment * scaled_tol;
|
||||||
|
|
||||||
|
// set up constants and variables related to minimizing the soft loss
|
||||||
|
const GRAD_TOL: f64 = 1e-12;
|
||||||
|
let mut grad_size = f64::INFINITY;
|
||||||
|
let mut softness = 1.0;
|
||||||
|
|
||||||
// convert the frozen indices to stacked format
|
// convert the frozen indices to stacked format
|
||||||
let frozen_stacked: Vec<usize> = frozen.into_iter().map(
|
let frozen_stacked: Vec<usize> = frozen.into_iter().map(
|
||||||
|MatrixEntry { index: (row, col), .. }| col*element_dim + row
|
|MatrixEntry { index: (row, col), .. }| col*element_dim + row
|
||||||
).collect();
|
).collect();
|
||||||
|
|
||||||
// use a regularized Newton's method with backtracking
|
// use a regularized Newton's method with backtracking
|
||||||
let mut state = SearchState::from_config(gram, frozen.freeze(guess));
|
let mut state = SearchState::from_config(gram, soft, softness, frozen.freeze(guess));
|
||||||
let mut hess = DMatrix::zeros(element_dim, assembly_dim);
|
let mut hess = DMatrix::zeros(element_dim, assembly_dim);
|
||||||
for _ in 0..max_descent_steps {
|
for _ in 0..max_descent_steps {
|
||||||
// find the negative gradient of the loss function
|
// find the negative gradient of the loss function
|
||||||
|
@ -426,7 +441,7 @@ pub fn realize_gram(
|
||||||
let neg_d_err =
|
let neg_d_err =
|
||||||
basis_mat.tr_mul(&*Q) * &state.config
|
basis_mat.tr_mul(&*Q) * &state.config
|
||||||
+ state.config.tr_mul(&*Q) * &basis_mat;
|
+ state.config.tr_mul(&*Q) * &basis_mat;
|
||||||
let neg_d_err_proj = gram.proj(&neg_d_err);
|
let neg_d_err_proj = gram.proj(&neg_d_err) + softness * soft.proj(&neg_d_err);
|
||||||
let deriv_grad = 4.0 * &*Q * (
|
let deriv_grad = 4.0 * &*Q * (
|
||||||
-&basis_mat * &state.err_proj
|
-&basis_mat * &state.err_proj
|
||||||
+ &state.config * &neg_d_err_proj
|
+ &state.config * &neg_d_err_proj
|
||||||
|
@ -455,10 +470,13 @@ pub fn realize_gram(
|
||||||
hess[(k, k)] = 1.0;
|
hess[(k, k)] = 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop if the loss is tolerably low
|
// stop if the hard loss is tolerably low and the total loss is close to
|
||||||
|
// stationary. we use `neg_grad_stacked` to measure the size of the
|
||||||
|
// gradient because it's been projected onto the frozen subspace
|
||||||
history.config.push(state.config.clone());
|
history.config.push(state.config.clone());
|
||||||
history.scaled_loss.push(state.loss / scale_adjustment);
|
history.scaled_loss.push(state.loss_hard / scale_adjustment);
|
||||||
if state.loss < tol { break; }
|
grad_size = neg_grad_stacked.norm_squared();
|
||||||
|
if state.loss_hard < tol && grad_size < softness * GRAD_TOL { break; }
|
||||||
|
|
||||||
// compute the Newton step
|
// compute the Newton step
|
||||||
/* TO DO */
|
/* TO DO */
|
||||||
|
@ -482,7 +500,7 @@ pub fn realize_gram(
|
||||||
|
|
||||||
// use backtracking line search to find a better configuration
|
// use backtracking line search to find a better configuration
|
||||||
if let Some((better_state, backoff_steps)) = seek_better_config(
|
if let Some((better_state, backoff_steps)) = seek_better_config(
|
||||||
gram, &state, &base_step, neg_grad.dot(&base_step),
|
gram, soft, softness, &state, &base_step, neg_grad.dot(&base_step),
|
||||||
min_efficiency, backoff, max_backoff_steps,
|
min_efficiency, backoff, max_backoff_steps,
|
||||||
) {
|
) {
|
||||||
state = better_state;
|
state = better_state;
|
||||||
|
@ -493,8 +511,18 @@ pub fn realize_gram(
|
||||||
history,
|
history,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if we're near a minimum of the total loss, but the hard loss still
|
||||||
|
// isn't tolerably low, make the soft constraints softer
|
||||||
|
const SOFTNESS_BACKOFF_THRESHOLD: f64 = 1e-6;
|
||||||
|
const SOFTNESS_BACKOFF: f64 = 0.95;
|
||||||
|
if state.loss_hard >= tol && grad_size < softness * SOFTNESS_BACKOFF_THRESHOLD {
|
||||||
|
softness *= SOFTNESS_BACKOFF;
|
||||||
|
state = SearchState::from_config(gram, soft, softness, state.config);
|
||||||
|
console_log!("Softness decreased to {softness}");
|
||||||
}
|
}
|
||||||
let result = if state.loss < tol {
|
}
|
||||||
|
let result = if state.loss_hard < tol && grad_size < softness * GRAD_TOL {
|
||||||
// express the uniform basis in the standard basis
|
// express the uniform basis in the standard basis
|
||||||
const UNIFORM_DIM: usize = 4;
|
const UNIFORM_DIM: usize = 4;
|
||||||
let total_dim_unif = UNIFORM_DIM * assembly_dim;
|
let total_dim_unif = UNIFORM_DIM * assembly_dim;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue