Compare commits

..

1 commit

Author SHA1 Message Date
2c8c09d20d feat: Point coordinate regulators (#118)
Implement regulators for the Euclidean coordinates of `Point` entities, automatically creating all three of them for each added point entity. When such a regulator is set, it freezes the corresponding representation coordinate to the set point. In addition, if all three coordinates of a given `Point` are set, the coradius coordinate (which holds the norm of the point) is frozen as well.

Note that a `PointCoordinateRegulator` must be created with a `Point` as the subject. This commit modifies `HalfCurvatureRegulator` analogously, so that it can only be created with a `Sphere`.
Co-authored-by: Glen Whitney <glen@studioinfinity.org>
Co-committed-by: Glen Whitney <glen@studioinfinity.org>
2025-10-13 22:52:02 +00:00
8 changed files with 101 additions and 3044 deletions

View file

@ -12,11 +12,11 @@ Note that currently this is just the barest beginnings of the project, more of a
### Implementation goals
* Comfortable, intuitive UI
* Provide a comfortable, intuitive UI
* Able to run in browser (so implemented in WASM-compatible language)
* Allow execution in browser (so implemented in WASM-compatible language)
* Produce scalable graphics of 3D diagrams, and maybe STL files (or other fabricatable file format) as well.
* Produce scalable graphics of 3D diagrams, and maybe STL files (or other fabricatable file format) as well
## Prototype
@ -24,38 +24,40 @@ The latest prototype is in the folder `app-proto`. It includes both a user inter
### Install the prerequisites
1. Install [`rustup`](https://rust-lang.github.io/rustup/): the officially recommended Rust toolchain manager
- It's available on Ubuntu as a [Snap](https://snapcraft.io/rustup)
2. Call `rustup default stable` to "download the latest stable release of Rust and set it as your default toolchain"
- If you forget, the `rustup` [help system](https://github.com/rust-lang/rustup/blob/d9b3601c3feb2e88cf3f8ca4f7ab4fdad71441fd/src/errors.rs#L109-L112) will remind you
3. Call `rustup target add wasm32-unknown-unknown` to add the [most generic 32-bit WebAssembly target](https://doc.rust-lang.org/nightly/rustc/platform-support/wasm32-unknown-unknown.html)
4. Call `cargo install wasm-pack` to install the [WebAssembly toolchain](https://rustwasm.github.io/docs/wasm-pack/)
5. Call `cargo install trunk` to install the [Trunk](https://trunkrs.dev/) web-build tool
1. Install [`rustup`](https://rust-lang.github.io/rustup/): the officially recommended Rust toolchain manager.
- It's available on Ubuntu as a [Snap](https://snapcraft.io/rustup).
2. Call `rustup default stable` to "download the latest stable release of Rust and set it as your default toolchain".
- If you forget, the `rustup` [help system](https://github.com/rust-lang/rustup/blob/d9b3601c3feb2e88cf3f8ca4f7ab4fdad71441fd/src/errors.rs#L109-L112) will remind you.
3. Call `rustup target add wasm32-unknown-unknown` to add the [most generic 32-bit WebAssembly target](https://doc.rust-lang.org/nightly/rustc/platform-support/wasm32-unknown-unknown.html).
4. Call `cargo install wasm-pack` to install the [WebAssembly toolchain](https://rustwasm.github.io/docs/wasm-pack/).
5. Call `cargo install trunk` to install the [Trunk](https://trunkrs.dev/) web-build tool.
- In the future, `trunk` can be updated with the same command. (You may need the `--locked` flag if your ambient version of `rustc` does not match that required by `trunk`.)
6. Add the `.cargo/bin` folder in your home directory to your executable search path
- This lets you call Trunk, and other tools installed by Cargo, without specifying their paths
- On POSIX systems, the search path is stored in the `PATH` environment variable
- This lets you call Trunk, and other tools installed by Cargo, without specifying their paths.
- On POSIX systems, the search path is stored in the `PATH` environment variable.
- Alternatively, if you don't want to adjust your `PATH`, you can install `trunk` in another directory `DIR` via `cargo install --root DIR trunk`.
### Play with the prototype
1. From the `app-proto` folder, call `trunk serve --release` to build and serve the prototype
- The crates the prototype depends on will be downloaded and served automatically
- For a faster build, at the expense of a much slower prototype, you can call `trunk serve` without the `--release` flag
1. From the `app-proto` folder, call `trunk serve --release` to build and serve the prototype.
- The crates the prototype depends on will be downloaded and served automatically.
- For a faster build, at the expense of a much slower prototype, you can call `trunk serve` without the `--release` flag.
- If you want to stay in the top-level folder, you can call `trunk serve --config app-proto [--release]` from there instead.
3. In a web browser, visit one of the URLs listed under the message `INFO 📡 server listening at:`
- Touching any file in the `app-proto` folder will make Trunk rebuild and live-reload the prototype
4. Press *ctrl+C* in the shell where Trunk is running to stop serving the prototype
3. In a web browser, visit one of the URLs listed under the message `INFO 📡 server listening at:`.
- Touching any file in the `app-proto` folder will make Trunk rebuild and live-reload the prototype.
4. Press *ctrl+C* in the shell where Trunk is running to stop serving the prototype.
### Run the engine on some example problems
1. Use `sh` to run the script `tools/run-examples.sh`
- The script is location-independent, so you can do this from anywhere in the dyna3 repository
1. Use `sh` to run the script `tools/run-examples.sh`.
- The script is location-independent, so you can do this from anywhere in the dyna3 repository.
- The call from the top level of the repository is:
```bash
sh tools/run-examples.sh
```
- For each example problem, the engine will print the value of the loss function at each optimization step
- The first example that prints is the same as the Irisawa hexlet example from the Julia version of the engine prototype. If you go into `engine-proto/gram-test`, launch Julia, and then
- For each example problem, the engine will print the value of the loss function at each optimization step.
- The first example that prints is the same as the Irisawa hexlet example from the Julia version of the engine prototype. If you go into `engine-proto/gram-test`, launch Julia, and then execute
```julia
include("irisawa-hexlet.jl")
@ -64,24 +66,24 @@ The latest prototype is in the folder `app-proto`. It includes both a user inter
end
```
you should see that it prints basically the same loss history until the last few steps, when the lower default precision of the Rust engine really starts to show
you should see that it prints basically the same loss history until the last few steps, when the lower default precision of the Rust engine really starts to show.
### Run the automated tests
1. Go into the `app-proto` folder
2. Call `cargo test`
1. Go into the `app-proto` folder.
2. Call `cargo test`.
### Deploy the prototype
1. From the `app-proto` folder, call `trunk build --release`
- Building in [release mode](https://doc.rust-lang.org/cargo/reference/profiles.html#release) produces an executable which is smaller and often much faster, but harder to debug and more time-consuming to build
- If you want to stay in the top-level folder, you can call `trunk build --config app-proto --release` from there instead
1. From the `app-proto` folder, call `trunk build --release`.
- Building in [release mode](https://doc.rust-lang.org/cargo/reference/profiles.html#release) produces an executable which is smaller and often much faster, but harder to debug and more time-consuming to build.
- If you want to stay in the top-level folder, you can call `trunk build --config app-proto --release` from there instead.
2. Use `sh` to run the packaging script `tools/package-for-deployment.sh`.
- The script is location-independent, so you can do this from anywhere in the dyna3 repository
- The script is location-independent, so you can do this from anywhere in the dyna3 repository.
- The call from the top level of the repository is:
```bash
sh tools/package-for-deployment.sh
```
- This will overwrite or replace the files in `deploy/dyna3`
- This will overwrite or replace the files in `deploy/dyna3`.
3. Put the contents of `deploy/dyna3` in the folder on your server that the prototype will be served from.
- To simplify uploading, you might want to combine these files into an archive called `deploy/dyna3.zip`. Git has been set to ignore this path
- To simplify uploading, you might want to combine these files into an archive called `deploy/dyna3.zip`. Git has been set to ignore this path.

View file

@ -227,16 +227,6 @@ details[open]:has(li) .element-switch::after {
border-radius: 8px;
}
#distortion-bar {
display: flex;
margin-top: 8px;
gap: 8px;
}
#distortion-gauge {
flex-grow: 1;
}
/* display */
#display {

View file

@ -1,13 +1,11 @@
use enum_iterator::{all, Sequence};
use nalgebra::{DMatrix, DVector, DVectorView};
use std::{
any::Any,
cell::Cell,
cmp::Ordering,
collections::{BTreeMap, BTreeSet},
f64::consts::SQRT_2,
fmt,
fmt::{Debug, Formatter},
fmt::{Debug, Display, Formatter},
hash::{Hash, Hasher},
rc::Rc,
sync::{atomic, atomic::AtomicU64},
@ -88,6 +86,14 @@ impl Ord for dyn Serial {
}
}
// Small helper function to generate consistent errors when there
// are indexing issues in a ProblemPoser
fn indexing_error(item: &str, name: &str, actor: &str) -> String {
format!(
"{item} \"{name}\" must be indexed before {actor} writes problem data"
)
}
pub trait ProblemPoser {
fn pose(&self, problem: &mut ConstraintProblem);
}
@ -126,16 +132,11 @@ pub trait Element: Serial + ProblemPoser + DisplayItem {
// be used carefully to preserve invariant (1), described in the comment on
// the `tangent` field of the `Assembly` structure
fn set_column_index(&self, index: usize);
/* KLUDGE */
fn is_point(&self) -> bool {
false
}
}
impl Debug for dyn Element {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
self.id().fmt(f)
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.id(), f)
}
}
@ -258,8 +259,7 @@ impl Serial for Sphere {
impl ProblemPoser for Sphere {
fn pose(&self, problem: &mut ConstraintProblem) {
let index = self.column_index().expect(
format!("Sphere \"{}\" should be indexed before writing problem data", self.id).as_str()
);
indexing_error("Sphere", &self.id, "it").as_str());
problem.gram.push_sym(index, index, 1.0);
problem.guess.set_column(index, &self.representation.get_clone_untracked());
}
@ -353,10 +353,6 @@ impl Element for Point {
fn set_column_index(&self, index: usize) {
self.column_index.set(Some(index));
}
fn is_point(&self) -> bool {
true
}
}
impl Serial for Point {
@ -368,25 +364,17 @@ impl Serial for Point {
impl ProblemPoser for Point {
fn pose(&self, problem: &mut ConstraintProblem) {
let index = self.column_index().expect(
format!("Point \"{}\" should be indexed before writing problem data", self.id).as_str()
);
indexing_error("Point", &self.id, "it").as_str());
problem.gram.push_sym(index, index, 0.0);
problem.frozen.push(Self::WEIGHT_COMPONENT, index, 0.5);
problem.guess.set_column(index, &self.representation.get_clone_untracked());
}
}
pub trait Regulator: Any + Serial + ProblemPoser + OutlineItem {
pub trait Regulator: Serial + ProblemPoser + OutlineItem {
fn subjects(&self) -> Vec<Rc<dyn Element>>;
fn measurement(&self) -> ReadSignal<f64>;
fn set_point(&self) -> Signal<SpecifiedValue>;
fn soft(&self) -> Option<Signal<bool>> {
None
}
fn distortion(&self) -> Option<ReadSignal<f64>> { /* KLUDGE */
None
}
fn as_any(&self) -> &dyn Any;
}
impl Hash for dyn Regulator {
@ -419,8 +407,6 @@ pub struct InversiveDistanceRegulator {
pub subjects: [Rc<dyn Element>; 2],
pub measurement: ReadSignal<f64>,
pub set_point: Signal<SpecifiedValue>,
pub soft: Signal<bool>,
distortion: Option<ReadSignal<f64>>, /* KLUDGE */
serial: u64,
}
@ -436,24 +422,9 @@ impl InversiveDistanceRegulator {
});
let set_point = create_signal(SpecifiedValue::from_empty_spec());
let distortion = if subjects.iter().all(|subj| subj.is_point()) {
Some(create_memo(move || {
let set_point_opt = set_point.with(|set_pt| set_pt.value);
let measurement_val = measurement.get();
match set_point_opt {
None => 0.0,
Some(set_point_val) => SQRT_2 * (
(-measurement_val).sqrt() - (-set_point_val).sqrt()
),
}
}))
} else {
None
};
let soft = create_signal(false);
let serial = Self::next_serial();
Self { subjects, measurement, set_point, soft, distortion, serial }
Self { subjects, measurement, set_point, serial }
}
}
@ -469,18 +440,6 @@ impl Regulator for InversiveDistanceRegulator {
fn set_point(&self) -> Signal<SpecifiedValue> {
self.set_point
}
fn soft(&self) -> Option<Signal<bool>> {
Some(self.soft)
}
fn distortion(&self) -> Option<ReadSignal<f64>> {
self.distortion
}
fn as_any(&self) -> &dyn Any {
self
}
}
impl Serial for InversiveDistanceRegulator {
@ -491,19 +450,14 @@ impl Serial for InversiveDistanceRegulator {
impl ProblemPoser for InversiveDistanceRegulator {
fn pose(&self, problem: &mut ConstraintProblem) {
let soft = self.soft.get_untracked();
self.set_point.with_untracked(|set_pt| {
if let Some(val) = set_pt.value {
let [row, col] = self.subjects.each_ref().map(
|subj| subj.column_index().expect(
"Subjects should be indexed before inversive distance regulator writes problem data"
)
indexing_error("Subject", subj.id(),
"inversive distance regulator").as_str())
);
if soft {
problem.soft.push_sym(row, col, val);
} else {
problem.gram.push_sym(row, col, val);
}
problem.gram.push_sym(row, col, val);
}
});
}
@ -541,10 +495,6 @@ impl Regulator for HalfCurvatureRegulator {
fn set_point(&self) -> Signal<SpecifiedValue> {
self.set_point
}
fn as_any(&self) -> &dyn Any {
self
}
}
impl Serial for HalfCurvatureRegulator {
@ -558,8 +508,8 @@ impl ProblemPoser for HalfCurvatureRegulator {
self.set_point.with_untracked(|set_pt| {
if let Some(val) = set_pt.value {
let col = self.subject.column_index().expect(
"Subject should be indexed before half-curvature regulator writes problem data"
);
indexing_error("Subject", &self.subject.id,
"half-curvature regulator").as_str());
problem.frozen.push(Sphere::CURVATURE_COMPONENT, col, val);
}
});
@ -567,11 +517,18 @@ impl ProblemPoser for HalfCurvatureRegulator {
}
#[derive(Clone, Copy, Sequence)]
pub enum Axis {X = 0, Y = 1, Z = 2}
pub enum Axis { X = 0, Y = 1, Z = 2 }
impl Axis {
pub const N_AXIS: usize = (Axis::Z as usize) + 1;
pub const NAME: [&str; Axis::N_AXIS] = ["X", "Y", "Z"];
fn name(&self) -> &'static str {
match self { Axis::X => "X", Axis::Y => "Y", Axis::Z => "Z" }
}
}
impl Display for Axis {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.name())
}
}
pub struct PointCoordinateRegulator {
@ -597,21 +554,9 @@ impl Serial for PointCoordinateRegulator {
}
impl Regulator for PointCoordinateRegulator {
fn subjects(&self) -> Vec<Rc<dyn Element>> {
vec![self.subject.clone()]
}
fn measurement(&self) -> ReadSignal<f64> {
self.measurement
}
fn set_point(&self) -> Signal<SpecifiedValue> {
self.set_point
}
fn as_any(&self) -> &dyn Any {
self
}
fn subjects(&self) -> Vec<Rc<dyn Element>> { vec![self.subject.clone()] }
fn measurement(&self) -> ReadSignal<f64> { self.measurement }
fn set_point(&self) -> Signal<SpecifiedValue> { self.set_point }
}
impl ProblemPoser for PointCoordinateRegulator {
@ -619,19 +564,20 @@ impl ProblemPoser for PointCoordinateRegulator {
self.set_point.with_untracked(|set_pt| {
if let Some(val) = set_pt.value {
let col = self.subject.column_index().expect(
"Subject must be indexed before point-coordinate regulator poses.");
indexing_error("Subject", &self.subject.id,
"point-coordinate regulator").as_str());
problem.frozen.push(self.axis as usize, col, val);
// Check if all three coordinates have been frozen, and if so,
// freeze the coradius as well
let mut coords = [0.0; Axis::N_AXIS];
// If all three of the subject's spatial coordinates have been
// frozen, then freeze its norm component:
let mut coords = [0.0; Axis::CARDINALITY];
let mut nset: usize = 0;
for &MatrixEntry {index, value} in &(problem.frozen) {
if index.1 == col && index.0 < Axis::N_AXIS {
if index.1 == col && index.0 < Axis::CARDINALITY {
nset += 1;
coords[index.0] = value
}
}
if nset == Axis::N_AXIS {
if nset == Axis::CARDINALITY {
let [x, y, z] = coords;
problem.frozen.push(
Point::NORM_COMPONENT, col, point(x,y,z)[Point::NORM_COMPONENT]);
@ -841,6 +787,7 @@ impl Assembly {
/* DEBUG */
// log the Gram matrix
console_log!("Gram matrix:\n{}", problem.gram);
console_log!("Frozen entries:\n{}", problem.frozen);
/* DEBUG */
// log the initial configuration matrix
@ -848,7 +795,7 @@ impl Assembly {
// look for a configuration with the given Gram matrix
let Realization { result, history } = realize_gram(
&problem, 1.0e-20, 0.5, 0.9, 1.1, 400, 110
&problem, 1.0e-12, 0.5, 0.9, 1.1, 200, 110
);
/* DEBUG */
@ -1000,7 +947,8 @@ mod tests {
use crate::engine;
#[test]
#[should_panic(expected = "Sphere \"sphere\" should be indexed before writing problem data")]
#[should_panic(expected =
"Sphere \"sphere\" must be indexed before it writes problem data")]
fn unindexed_element_test() {
let _ = create_root(|| {
let elt = Sphere::default("sphere".to_string(), 0);
@ -1009,7 +957,8 @@ mod tests {
}
#[test]
#[should_panic(expected = "Subjects should be indexed before inversive distance regulator writes problem data")]
#[should_panic(expected = "Subject \"sphere1\" must be indexed before \
inversive distance regulator writes problem data")]
fn unindexed_subject_test_inversive_distance() {
let _ = create_root(|| {
let subjects = [0, 1].map(

View file

@ -111,94 +111,6 @@ fn StepInput() -> View {
}
}
#[component]
fn DistortionGauge() -> View {
let state = use_context::<AppState>();
let total_distortion = create_memo(move || {
state.assembly.regulators.with(|regs| {
let mut total = 0.0;
for reg in regs {
if let Some(distortion) = reg.distortion() {
total += distortion.get().abs();
}
}
total
})
});
view! {
div(id = "distortion-gauge") {
"Distortion: " (total_distortion.with(|distort| distort.to_string()))
}
}
}
#[component]
fn PrintButton() -> View {
view! {
button(
on:click = |_| {
let state = use_context::<AppState>();
// print the edge length distortions
let mut hard_distortion_table = String::new();
let mut soft_distortion_table = String::new();
let mut highest_distortion = f64::NEG_INFINITY;
let mut lowest_distortion = f64::INFINITY;
let mut largest_hard_distortion = f64::NEG_INFINITY;
state.assembly.regulators.with_untracked(|regs| {
for reg in regs {
if let Some(distortion) = reg.distortion() {
let distortion_val = distortion.get();
let subjects = reg.subjects();
let distortion_line = format!(
"{}, {}: {distortion_val}\n",
subjects[0].id(),
subjects[1].id(),
);
match reg.soft() {
Some(soft) if soft.get() => {
soft_distortion_table += &distortion_line;
highest_distortion = highest_distortion.max(distortion_val);
lowest_distortion = lowest_distortion.min(distortion_val);
},
_ => {
hard_distortion_table += &distortion_line;
largest_hard_distortion = largest_hard_distortion.max(distortion_val.abs());
}
};
}
}
});
console_log!("\
=== Distortions of flexible edges (for labels) ===\n\n\
--- Range ---\n\n\
Highest: {highest_distortion}\n\
Lowest: {lowest_distortion}\n\n\
--- Table ---\n\n{soft_distortion_table}\n\
=== Distortions of rigid edges (for validation) ===\n\n\
These values should be small relative to the ones for the flexible edges\n\n\
--- Range ---\n\n\
Largest absolute: {largest_hard_distortion}\n\n\
--- Table ---\n\n{hard_distortion_table}\
");
// print the vertex coordinates
let mut coords_table = String::new();
state.assembly.elements.with_untracked(|elts| {
for elt in elts.iter().filter(|elt| elt.is_point()) {
let (x, y, z) = elt.representation().with(
|rep| (rep[0], rep[1], rep[2])
);
coords_table += &format!("{}: {x}, {y}, {z}\n", elt.id());
}
});
console_log!("=== Vertex coordinates ===\n\n{coords_table}");
},
) { "Print" }
}
}
fn into_log10_time_point((step, value): (usize, f64)) -> Vec<Option<f64>> {
vec![
Some(step as f64),
@ -403,10 +315,6 @@ pub fn Diagnostics() -> View {
}
DiagnosticsPanel(name = "loss") { LossHistory {} }
DiagnosticsPanel(name = "spectrum") { SpectrumHistory {} }
div(id = "distortion-bar") {
DistortionGauge {}
PrintButton {}
}
}
}
}

View file

@ -597,16 +597,16 @@ pub fn Display() -> View {
|status| status.is_ok()
);
let step_val = state.assembly.step.with_untracked(|step| step.value);
let on_init_step = step_val.is_some_and(|n| n == 0.0);
let on_last_step = step_val.is_some_and(
|n| state.assembly.descent_history.with_untracked(
|history| n as usize + 1 == history.config.len().max(1)
)
);
if
state.selection.with(|sel| sel.len() == 1)
&& realization_successful
&& on_last_step
{
let on_manipulable_step =
!realization_successful && on_init_step
|| realization_successful && on_last_step;
if on_manipulable_step && state.selection.with(|sel| sel.len() == 1) {
let sel = state.selection.with(
|sel| sel.into_iter().next().unwrap().clone()
);

View file

@ -6,7 +6,6 @@ use web_sys::{KeyboardEvent, MouseEvent, wasm_bindgen::JsCast};
use crate::{
AppState,
assembly::{
Axis,
Element,
HalfCurvatureRegulator,
InversiveDistanceRegulator,
@ -123,10 +122,11 @@ impl OutlineItem for HalfCurvatureRegulator {
impl OutlineItem for PointCoordinateRegulator {
fn outline_item(self: Rc<Self>, _element: &Rc<dyn Element>) -> View {
let name = format!("{} coordinate", self.axis);
view! {
li(class = "regulator") {
div(class = "regulator-label") { (Axis::NAME[self.axis as usize]) }
div(class = "regulator-type") { "Coordinate" }
div(class = "regulator-label") // for spacing
div(class = "regulator-type") { (name) }
RegulatorInput(regulator = self)
div(class = "status")
}

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,6 @@
use lazy_static::lazy_static;
use nalgebra::{Const, DMatrix, DVector, DVectorView, Dyn, SymmetricEigen};
use std::fmt::{Display, Error, Formatter};
use sycamore::prelude::console_log; /* DEBUG */
// --- elements ---
@ -47,7 +46,7 @@ pub fn project_sphere_to_normalized(rep: &mut DVector<f64>) {
// normalize a point's representation vector by scaling
pub fn project_point_to_normalized(rep: &mut DVector<f64>) {
rep.scale_mut(0.5 / rep[3]); //FIXME: This 3 should be Point::WEIGHT_COMPONENT
rep.scale_mut(0.5 / rep[3]);
}
// --- partial matrices ---
@ -241,7 +240,6 @@ impl DescentHistory {
pub struct ConstraintProblem {
pub gram: PartialMatrix,
pub soft: PartialMatrix,
pub frozen: PartialMatrix,
pub guess: DMatrix<f64>,
}
@ -251,7 +249,6 @@ impl ConstraintProblem {
const ELEMENT_DIM: usize = 5;
Self {
gram: PartialMatrix::new(),
soft: PartialMatrix::new(),
frozen: PartialMatrix::new(),
guess: DMatrix::<f64>::zeros(ELEMENT_DIM, element_count),
}
@ -261,7 +258,6 @@ impl ConstraintProblem {
pub fn from_guess(guess_columns: &[DVector<f64>]) -> Self {
Self {
gram: PartialMatrix::new(),
soft: PartialMatrix::new(),
frozen: PartialMatrix::new(),
guess: DMatrix::from_columns(guess_columns),
}
@ -284,18 +280,14 @@ lazy_static! {
struct SearchState {
config: DMatrix<f64>,
err_proj: DMatrix<f64>,
loss_hard: f64,
loss: f64,
}
impl SearchState {
fn from_config(gram: &PartialMatrix, soft: &PartialMatrix, softness: f64, config: DMatrix<f64>) -> Self {
let config_gram = &(config.tr_mul(&*Q) * &config);
let err_proj_hard = gram.sub_proj(config_gram);
let err_proj = &err_proj_hard + softness * soft.sub_proj(config_gram);
let loss_hard = err_proj_hard.norm_squared();
fn from_config(gram: &PartialMatrix, config: DMatrix<f64>) -> Self {
let err_proj = gram.sub_proj(&(config.tr_mul(&*Q) * &config));
let loss = err_proj.norm_squared();
Self { config, err_proj, loss_hard, loss }
Self { config, err_proj, loss }
}
}
@ -339,8 +331,6 @@ pub fn local_unif_to_std(v: DVectorView<f64>) -> DMatrix<f64> {
// use backtracking line search to find a better configuration
fn seek_better_config(
gram: &PartialMatrix,
soft: &PartialMatrix,
softness: f64,
state: &SearchState,
base_step: &DMatrix<f64>,
base_target_improvement: f64,
@ -351,7 +341,7 @@ fn seek_better_config(
let mut rate = 1.0;
for backoff_steps in 0..max_backoff_steps {
let trial_config = &state.config + rate * base_step;
let trial_state = SearchState::from_config(gram, soft, softness, trial_config);
let trial_state = SearchState::from_config(gram, trial_config);
let improvement = state.loss - trial_state.loss;
if improvement >= min_efficiency * rate * base_target_improvement {
return Some((trial_state, backoff_steps));
@ -386,7 +376,7 @@ pub fn realize_gram(
max_backoff_steps: i32,
) -> Realization {
// destructure the problem data
let ConstraintProblem { gram, soft, guess, frozen } = problem;
let ConstraintProblem { gram, guess, frozen } = problem;
// start the descent history
let mut history = DescentHistory::new();
@ -413,18 +403,13 @@ pub fn realize_gram(
let scale_adjustment = (gram.0.len() as f64).sqrt();
let tol = scale_adjustment * scaled_tol;
// set up constants and variables related to minimizing the soft loss
const GRAD_TOL: f64 = 1e-12;
let mut grad_size = f64::INFINITY;
let mut softness = 1.0;
// convert the frozen indices to stacked format
let frozen_stacked: Vec<usize> = frozen.into_iter().map(
|MatrixEntry { index: (row, col), .. }| col*element_dim + row
).collect();
// use a regularized Newton's method with backtracking
let mut state = SearchState::from_config(gram, soft, softness, frozen.freeze(guess));
let mut state = SearchState::from_config(gram, frozen.freeze(guess));
let mut hess = DMatrix::zeros(element_dim, assembly_dim);
for _ in 0..max_descent_steps {
// find the negative gradient of the loss function
@ -441,7 +426,7 @@ pub fn realize_gram(
let neg_d_err =
basis_mat.tr_mul(&*Q) * &state.config
+ state.config.tr_mul(&*Q) * &basis_mat;
let neg_d_err_proj = gram.proj(&neg_d_err) + softness * soft.proj(&neg_d_err);
let neg_d_err_proj = gram.proj(&neg_d_err);
let deriv_grad = 4.0 * &*Q * (
-&basis_mat * &state.err_proj
+ &state.config * &neg_d_err_proj
@ -470,13 +455,10 @@ pub fn realize_gram(
hess[(k, k)] = 1.0;
}
// stop if the hard loss is tolerably low and the total loss is close to
// stationary. we use `neg_grad_stacked` to measure the size of the
// gradient because it's been projected onto the frozen subspace
// stop if the loss is tolerably low
history.config.push(state.config.clone());
history.scaled_loss.push(state.loss_hard / scale_adjustment);
grad_size = neg_grad_stacked.norm_squared();
if state.loss_hard < tol && grad_size < softness * GRAD_TOL { break; }
history.scaled_loss.push(state.loss / scale_adjustment);
if state.loss < tol { break; }
// compute the Newton step
/* TO DO */
@ -500,7 +482,7 @@ pub fn realize_gram(
// use backtracking line search to find a better configuration
if let Some((better_state, backoff_steps)) = seek_better_config(
gram, soft, softness, &state, &base_step, neg_grad.dot(&base_step),
gram, &state, &base_step, neg_grad.dot(&base_step),
min_efficiency, backoff, max_backoff_steps,
) {
state = better_state;
@ -511,18 +493,8 @@ pub fn realize_gram(
history,
};
}
// if we're near a minimum of the total loss, but the hard loss still
// isn't tolerably low, make the soft constraints softer
const SOFTNESS_BACKOFF_THRESHOLD: f64 = 1e-6;
const SOFTNESS_BACKOFF: f64 = 0.95;
if state.loss_hard >= tol && grad_size < softness * SOFTNESS_BACKOFF_THRESHOLD {
softness *= SOFTNESS_BACKOFF;
state = SearchState::from_config(gram, soft, softness, state.config);
console_log!("Softness decreased to {softness}");
}
}
let result = if state.loss_hard < tol && grad_size < softness * GRAD_TOL {
let result = if state.loss < tol {
// express the uniform basis in the standard basis
const UNIFORM_DIM: usize = 4;
let total_dim_unif = UNIFORM_DIM * assembly_dim;