WIP: Clean up the outline view #16

Closed
Vectornaut wants to merge 29 commits from outline-cleanup into main
13 changed files with 1144 additions and 302 deletions

View File

@ -1,7 +1,7 @@
[package] [package]
name = "sketch-outline" name = "dyna3"
version = "0.1.0" version = "0.1.0"
authors = ["Aaron"] authors = ["Aaron Fenyes", "Glen Whitney"]
edition = "2021" edition = "2021"
[features] [features]
@ -10,6 +10,7 @@ default = ["console_error_panic_hook"]
[dependencies] [dependencies]
itertools = "0.13.0" itertools = "0.13.0"
js-sys = "0.3.70" js-sys = "0.3.70"
lazy_static = "1.5.0"
nalgebra = "0.33.0" nalgebra = "0.33.0"
rustc-hash = "2.0.0" rustc-hash = "2.0.0"
slab = "0.4.9" slab = "0.4.9"
@ -25,6 +26,7 @@ console_error_panic_hook = { version = "0.1.7", optional = true }
version = "0.3.69" version = "0.3.69"
features = [ features = [
'HtmlCanvasElement', 'HtmlCanvasElement',
'HtmlInputElement',
'Performance', 'Performance',
'WebGl2RenderingContext', 'WebGl2RenderingContext',
'WebGlBuffer', 'WebGlBuffer',

View File

@ -2,8 +2,10 @@
<html> <html>
<head> <head>
<meta charset="utf-8"/> <meta charset="utf-8"/>
<title>Sketch outline</title> <title>dyna3</title>
<link data-trunk rel="css" href="main.css"/> <link data-trunk rel="css" href="main.css"/>
<link href="https://fonts.bunny.net/css?family=fira-sans:ital,wght@0,400;1,400&display=swap" rel="stylesheet">
<link href="https://fonts.bunny.net/css?family=noto-emoji:wght@400&text=%f0%9f%94%97%e2%9a%a0&display=swap" rel="stylesheet">
</head> </head>
<body></body> <body></body>
</html> </html>

View File

@ -2,6 +2,7 @@ body {
margin: 0px; margin: 0px;
color: #fcfcfc; color: #fcfcfc;
background-color: #222; background-color: #222;
font-family: 'Fira Sans', sans-serif;
} }
/* sidebar */ /* sidebar */
@ -33,6 +34,11 @@ body {
font-size: large; font-size: large;
} }
/* KLUDGE */
#add-remove > button.emoji {
font-family: 'Noto Emoji', sans-serif;
}
/* outline */ /* outline */
#outline { #outline {
@ -93,10 +99,11 @@ details[open]:has(li) .elt-switch::after {
display: flex; display: flex;
} }
.elt-rep > div, .cst-rep { .elt-rep > div {
padding: 2px 0px 0px 0px; padding: 2px 0px 0px 0px;
font-size: 10pt; font-size: 10pt;
text-align: center; font-variant-numeric: tabular-nums;
text-align: right;
width: 56px; width: 56px;
} }
@ -104,10 +111,38 @@ details[open]:has(li) .elt-switch::after {
font-style: italic; font-style: italic;
} }
.cst > input { .cst.invalid {
color: #f58fc2;
}
.cst > input[type=checkbox] {
margin: 0px 8px 0px 0px; margin: 0px 8px 0px 0px;
} }
.cst > input[type=text] {
color: inherit;
background-color: inherit;
border: 1px solid #555;
border-radius: 2px;
}
.cst.invalid > input[type=text] {
border-color: #70495c;
}
.status {
width: 20px;
padding-left: 4px;
text-align: center;
font-family: 'Noto Emoji';
font-style: normal;
}
.invalid > .status::after, details:has(.invalid):not([open]) .status::after {
content: '⚠';
color: #f58fc2;
}
/* display */ /* display */
canvas { canvas {

8
app-proto/run-examples Executable file
View File

@ -0,0 +1,8 @@
# based on "Enabling print statements in Cargo tests", by Jon Almeida
#
# https://jonalmeida.com/posts/2015/01/23/print-cargo/
#
cargo test -- --nocapture engine::tests::irisawa_hexlet_test
cargo test -- --nocapture engine::tests::three_spheres_example
cargo test -- --nocapture engine::tests::point_on_sphere_example

View File

@ -1,4 +1,3 @@
use std::collections::BTreeSet; /* DEBUG */
use sycamore::prelude::*; use sycamore::prelude::*;
use web_sys::{console, wasm_bindgen::JsValue}; use web_sys::{console, wasm_bindgen::JsValue};
@ -7,68 +6,52 @@ use crate::{engine, AppState, assembly::{Assembly, Constraint, Element}};
/* DEBUG */ /* DEBUG */
fn load_gen_assemb(assembly: &Assembly) { fn load_gen_assemb(assembly: &Assembly) {
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("gemini_a"), String::from("gemini_a"),
label: String::from("Castor"), String::from("Castor"),
color: [1.00_f32, 0.25_f32, 0.00_f32], [1.00_f32, 0.25_f32, 0.00_f32],
rep: engine::sphere(0.5, 0.5, 0.0, 1.0), engine::sphere(0.5, 0.5, 0.0, 1.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("gemini_b"), String::from("gemini_b"),
label: String::from("Pollux"), String::from("Pollux"),
color: [0.00_f32, 0.25_f32, 1.00_f32], [0.00_f32, 0.25_f32, 1.00_f32],
rep: engine::sphere(-0.5, -0.5, 0.0, 1.0), engine::sphere(-0.5, -0.5, 0.0, 1.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("ursa_major"), String::from("ursa_major"),
label: String::from("Ursa major"), String::from("Ursa major"),
color: [0.25_f32, 0.00_f32, 1.00_f32], [0.25_f32, 0.00_f32, 1.00_f32],
rep: engine::sphere(-0.5, 0.5, 0.0, 0.75), engine::sphere(-0.5, 0.5, 0.0, 0.75)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("ursa_minor"), String::from("ursa_minor"),
label: String::from("Ursa minor"), String::from("Ursa minor"),
color: [0.25_f32, 1.00_f32, 0.00_f32], [0.25_f32, 1.00_f32, 0.00_f32],
rep: engine::sphere(0.5, -0.5, 0.0, 0.5), engine::sphere(0.5, -0.5, 0.0, 0.5)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("moon_deimos"), String::from("moon_deimos"),
label: String::from("Deimos"), String::from("Deimos"),
color: [0.75_f32, 0.75_f32, 0.00_f32], [0.75_f32, 0.75_f32, 0.00_f32],
rep: engine::sphere(0.0, 0.15, 1.0, 0.25), engine::sphere(0.0, 0.15, 1.0, 0.25)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("moon_phobos"), String::from("moon_phobos"),
label: String::from("Phobos"), String::from("Phobos"),
color: [0.00_f32, 0.75_f32, 0.50_f32], [0.00_f32, 0.75_f32, 0.50_f32],
rep: engine::sphere(0.0, -0.15, -1.0, 0.25), engine::sphere(0.0, -0.15, -1.0, 0.25)
constraints: BTreeSet::default() )
}
);
assembly.insert_constraint(
Constraint {
args: (
assembly.elements_by_id.with_untracked(|elts_by_id| elts_by_id["gemini_a"]),
assembly.elements_by_id.with_untracked(|elts_by_id| elts_by_id["gemini_b"])
),
rep: 0.5,
active: create_signal(true)
}
); );
} }
@ -76,76 +59,68 @@ fn load_gen_assemb(assembly: &Assembly) {
fn load_low_curv_assemb(assembly: &Assembly) { fn load_low_curv_assemb(assembly: &Assembly) {
let a = 0.75_f64.sqrt(); let a = 0.75_f64.sqrt();
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "central".to_string(), "central".to_string(),
label: "Central".to_string(), "Central".to_string(),
color: [0.75_f32, 0.75_f32, 0.75_f32], [0.75_f32, 0.75_f32, 0.75_f32],
rep: engine::sphere(0.0, 0.0, 0.0, 1.0), engine::sphere(0.0, 0.0, 0.0, 1.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "assemb_plane".to_string(), "assemb_plane".to_string(),
label: "Assembly plane".to_string(), "Assembly plane".to_string(),
color: [0.75_f32, 0.75_f32, 0.75_f32], [0.75_f32, 0.75_f32, 0.75_f32],
rep: engine::sphere_with_offset(0.0, 0.0, 1.0, 0.0, 0.0), engine::sphere_with_offset(0.0, 0.0, 1.0, 0.0, 0.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "side1".to_string(), "side1".to_string(),
label: "Side 1".to_string(), "Side 1".to_string(),
color: [1.00_f32, 0.00_f32, 0.25_f32], [1.00_f32, 0.00_f32, 0.25_f32],
rep: engine::sphere_with_offset(1.0, 0.0, 0.0, 1.0, 0.0), engine::sphere_with_offset(1.0, 0.0, 0.0, 1.0, 0.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "side2".to_string(), "side2".to_string(),
label: "Side 2".to_string(), "Side 2".to_string(),
color: [0.25_f32, 1.00_f32, 0.00_f32], [0.25_f32, 1.00_f32, 0.00_f32],
rep: engine::sphere_with_offset(-0.5, a, 0.0, 1.0, 0.0), engine::sphere_with_offset(-0.5, a, 0.0, 1.0, 0.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "side3".to_string(), "side3".to_string(),
label: "Side 3".to_string(), "Side 3".to_string(),
color: [0.00_f32, 0.25_f32, 1.00_f32], [0.00_f32, 0.25_f32, 1.00_f32],
rep: engine::sphere_with_offset(-0.5, -a, 0.0, 1.0, 0.0), engine::sphere_with_offset(-0.5, -a, 0.0, 1.0, 0.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "corner1".to_string(), "corner1".to_string(),
label: "Corner 1".to_string(), "Corner 1".to_string(),
color: [0.75_f32, 0.75_f32, 0.75_f32], [0.75_f32, 0.75_f32, 0.75_f32],
rep: engine::sphere(-4.0/3.0, 0.0, 0.0, 1.0/3.0), engine::sphere(-4.0/3.0, 0.0, 0.0, 1.0/3.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: "corner2".to_string(), "corner2".to_string(),
label: "Corner 2".to_string(), "Corner 2".to_string(),
color: [0.75_f32, 0.75_f32, 0.75_f32], [0.75_f32, 0.75_f32, 0.75_f32],
rep: engine::sphere(2.0/3.0, -4.0/3.0 * a, 0.0, 1.0/3.0), engine::sphere(2.0/3.0, -4.0/3.0 * a, 0.0, 1.0/3.0)
constraints: BTreeSet::default() )
}
); );
let _ = assembly.try_insert_element( let _ = assembly.try_insert_element(
Element { Element::new(
id: String::from("corner3"), String::from("corner3"),
label: String::from("Corner 3"), String::from("Corner 3"),
color: [0.75_f32, 0.75_f32, 0.75_f32], [0.75_f32, 0.75_f32, 0.75_f32],
rep: engine::sphere(2.0/3.0, 4.0/3.0 * a, 0.0, 1.0/3.0), engine::sphere(2.0/3.0, 4.0/3.0 * a, 0.0, 1.0/3.0)
constraints: BTreeSet::default() )
}
); );
} }
@ -198,6 +173,7 @@ pub fn AddRemove() -> View {
} }
) { "+" } ) { "+" }
button( button(
class="emoji", /* KLUDGE */
disabled={ disabled={
let state = use_context::<AppState>(); let state = use_context::<AppState>();
state.selection.with(|sel| sel.len() != 2) state.selection.with(|sel| sel.len() != 2)
@ -210,16 +186,21 @@ pub fn AddRemove() -> View {
(arg_vec[0].clone(), arg_vec[1].clone()) (arg_vec[0].clone(), arg_vec[1].clone())
} }
); );
let rep = create_signal(0.0);
let rep_valid = create_signal(false);
let active = create_signal(true);
state.assembly.insert_constraint(Constraint { state.assembly.insert_constraint(Constraint {
args: args, args: args,
rep: 0.0, rep: rep,
active: create_signal(true) rep_text: create_signal(String::new()),
rep_valid: rep_valid,
active: active,
}); });
state.selection.update(|sel| sel.clear()); state.selection.update(|sel| sel.clear());
/* DEBUG */ /* DEBUG */
// print updated constraint list // print updated constraint list
console::log_1(&JsValue::from("constraints:")); console::log_1(&JsValue::from("Constraints:"));
state.assembly.constraints.with(|csts| { state.assembly.constraints.with(|csts| {
for (_, cst) in csts.into_iter() { for (_, cst) in csts.into_iter() {
console::log_5( console::log_5(
@ -227,10 +208,22 @@ pub fn AddRemove() -> View {
&JsValue::from(cst.args.0), &JsValue::from(cst.args.0),
&JsValue::from(cst.args.1), &JsValue::from(cst.args.1),
&JsValue::from(":"), &JsValue::from(":"),
&JsValue::from(cst.rep) &JsValue::from(cst.rep.get_untracked())
); );
} }
}); });
// update the realization when the constraint becomes active
// and valid, or is edited while active and valid
create_effect(move || {
console::log_1(&JsValue::from(
format!("Constraint ({}, {}) updated", args.0, args.1)
));
rep.track();
if active.get() && rep_valid.get() {
state.assembly.realize();
}
});
} }
) { "🔗" } ) { "🔗" }
select(bind:value=assembly_name) { /* DEBUG */ select(bind:value=assembly_name) { /* DEBUG */

View File

@ -1,22 +1,49 @@
use nalgebra::DVector; use nalgebra::{DMatrix, DVector};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use slab::Slab; use slab::Slab;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use sycamore::prelude::*; use sycamore::prelude::*;
use web_sys::{console, wasm_bindgen::JsValue}; /* DEBUG */
use crate::engine::{realize_gram, PartialMatrix};
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct Element { pub struct Element {
pub id: String, pub id: String,
pub label: String, pub label: String,
pub color: [f32; 3], pub color: [f32; 3],
pub rep: DVector<f64>, pub rep: Signal<DVector<f64>>,
pub constraints: BTreeSet<usize> pub constraints: Signal<BTreeSet<usize>>,
// internal properties, not reflected in any view
pub index: usize
} }
impl Element {
pub fn new(
id: String,
label: String,
color: [f32; 3],
rep: DVector<f64>
) -> Element {
Element {
id: id,
label: label,
color: color,
rep: create_signal(rep),
constraints: create_signal(BTreeSet::default()),
index: 0
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct Constraint { pub struct Constraint {
pub args: (usize, usize), pub args: (usize, usize),
pub rep: f64, pub rep: Signal<f64>,
pub rep_text: Signal<String>,
pub rep_valid: Signal<bool>,
pub active: Signal<bool> pub active: Signal<bool>
} }
@ -40,6 +67,8 @@ impl Assembly {
} }
} }
// --- inserting elements and constraints ---
// insert an element into the assembly without checking whether we already // insert an element into the assembly without checking whether we already
// have an element with the same identifier. any element that does have the // have an element with the same identifier. any element that does have the
// same identifier will get kicked out of the `elements_by_id` index // same identifier will get kicked out of the `elements_by_id` index
@ -72,22 +101,103 @@ impl Assembly {
// create and insert a new element // create and insert a new element
self.insert_element_unchecked( self.insert_element_unchecked(
Element { Element::new(
id: id, id,
label: format!("Sphere {}", id_num), format!("Sphere {}", id_num),
color: [0.75_f32, 0.75_f32, 0.75_f32], [0.75_f32, 0.75_f32, 0.75_f32],
rep: DVector::<f64>::from_column_slice(&[0.0, 0.0, 0.0, 0.5, -0.5]), DVector::<f64>::from_column_slice(&[0.0, 0.0, 0.0, 0.5, -0.5])
constraints: BTreeSet::default() )
}
); );
} }
pub fn insert_constraint(&self, constraint: Constraint) { pub fn insert_constraint(&self, constraint: Constraint) {
let args = constraint.args; let args = constraint.args;
let key = self.constraints.update(|csts| csts.insert(constraint)); let key = self.constraints.update(|csts| csts.insert(constraint));
self.elements.update(|elts| { let arg_constraints = self.elements.with(
elts[args.0].constraints.insert(key); |elts| (elts[args.0].constraints, elts[args.1].constraints)
elts[args.1].constraints.insert(key); );
}) arg_constraints.0.update(|csts| csts.insert(key));
arg_constraints.1.update(|csts| csts.insert(key));
}
// --- realization ---
pub fn realize(&self) {
// index the elements
self.elements.update_silent(|elts| {
for (index, (_, elt)) in elts.into_iter().enumerate() {
elt.index = index;
}
});
// set up the Gram matrix and the initial configuration matrix
let (gram, guess) = self.elements.with_untracked(|elts| {
// set up the off-diagonal part of the Gram matrix
let mut gram_to_be = PartialMatrix::new();
self.constraints.with_untracked(|csts| {
for (_, cst) in csts {
if cst.active.get_untracked() && cst.rep_valid.get_untracked() {
let args = cst.args;
let row = elts[args.0].index;
let col = elts[args.1].index;
gram_to_be.push_sym(row, col, cst.rep.get_untracked());
}
}
});
// set up the initial configuration matrix and the diagonal of the
// Gram matrix
let mut guess_to_be = DMatrix::<f64>::zeros(5, elts.len());
for (_, elt) in elts {
let index = elt.index;
gram_to_be.push_sym(index, index, 1.0);
guess_to_be.set_column(index, &elt.rep.get_clone_untracked());
}
(gram_to_be, guess_to_be)
});
/* DEBUG */
// log the Gram matrix
console::log_1(&JsValue::from("Gram matrix:"));
gram.log_to_console();
/* DEBUG */
// log the initial configuration matrix
console::log_1(&JsValue::from("Old configuration:"));
for j in 0..guess.nrows() {
let mut row_str = String::new();
for k in 0..guess.ncols() {
row_str.push_str(format!(" {:>8.3}", guess[(j, k)]).as_str());
}
console::log_1(&JsValue::from(row_str));
}
// look for a configuration with the given Gram matrix
let (config, success, history) = realize_gram(
&gram, guess, &[],
1.0e-12, 0.5, 0.9, 1.1, 200, 110
);
/* DEBUG */
// report the outcome of the search
console::log_1(&JsValue::from(
if success {
"Target accuracy achieved!"
} else {
"Failed to reach target accuracy"
}
));
console::log_2(&JsValue::from("Steps:"), &JsValue::from(history.scaled_loss.len() - 1));
console::log_2(&JsValue::from("Loss:"), &JsValue::from(*history.scaled_loss.last().unwrap()));
if success {
// read out the solution
for (_, elt) in self.elements.get_clone_untracked() {
elt.rep.update(
|rep| rep.set_column(0, &config.column(elt.index))
);
}
}
} }
} }

View File

@ -103,7 +103,11 @@ pub fn Display() -> View {
// change listener // change listener
let scene_changed = create_signal(true); let scene_changed = create_signal(true);
create_effect(move || { create_effect(move || {
state.assembly.elements.track(); state.assembly.elements.with(|elts| {
for (_, elt) in elts {
elt.rep.track();
}
});
state.selection.track(); state.selection.track();
scene_changed.set(true); scene_changed.set(true);
}); });
@ -295,23 +299,40 @@ pub fn Display() -> View {
let assembly_to_world = &location * &orientation; let assembly_to_world = &location * &orientation;
// get the assembly // get the assembly
let elements = state.assembly.elements.get_clone(); let (
let element_iter = (&elements).into_iter(); elt_cnt,
let reps_world: Vec<_> = element_iter.clone().map(|(_, elt)| &assembly_to_world * &elt.rep).collect(); reps_world,
let colors: Vec<_> = element_iter.clone().map(|(key, elt)| colors,
if state.selection.with(|sel| sel.contains(&key)) { highlights
elt.color.map(|ch| 0.2 + 0.8*ch) ) = state.assembly.elements.with(|elts| {
} else { (
elt.color // number of elements
} elts.len() as i32,
).collect();
let highlights: Vec<_> = element_iter.map(|(key, _)| // representation vectors in world coordinates
if state.selection.with(|sel| sel.contains(&key)) { elts.iter().map(
1.0_f32 |(_, elt)| elt.rep.with(|rep| &assembly_to_world * rep)
} else { ).collect::<Vec<_>>(),
HIGHLIGHT
} // colors
).collect(); elts.iter().map(|(key, elt)| {
if state.selection.with(|sel| sel.contains(&key)) {
elt.color.map(|ch| 0.2 + 0.8*ch)
} else {
elt.color
}
}).collect::<Vec<_>>(),
// highlight levels
elts.iter().map(|(key, _)| {
if state.selection.with(|sel| sel.contains(&key)) {
1.0_f32
} else {
HIGHLIGHT
}
}).collect::<Vec<_>>()
)
});
// set the resolution // set the resolution
let width = canvas.width() as f32; let width = canvas.width() as f32;
@ -320,7 +341,7 @@ pub fn Display() -> View {
ctx.uniform1f(shortdim_loc.as_ref(), width.min(height)); ctx.uniform1f(shortdim_loc.as_ref(), width.min(height));
// pass the assembly // pass the assembly
ctx.uniform1i(sphere_cnt_loc.as_ref(), elements.len() as i32); ctx.uniform1i(sphere_cnt_loc.as_ref(), elt_cnt);
for n in 0..reps_world.len() { for n in 0..reps_world.len() {
let v = &reps_world[n]; let v = &reps_world[n];
ctx.uniform3f( ctx.uniform3f(

View File

@ -1,4 +1,12 @@
use nalgebra::DVector; use lazy_static::lazy_static;
use nalgebra::{Const, DMatrix, DVector, Dyn};
use web_sys::{console, wasm_bindgen::JsValue}; /* DEBUG */
// --- elements ---
pub fn point(x: f64, y: f64, z: f64) -> DVector<f64> {
DVector::from_column_slice(&[x, y, z, 0.5, 0.5*(x*x + y*y + z*z)])
}
// the sphere with the given center and radius, with inward-pointing normals // the sphere with the given center and radius, with inward-pointing normals
pub fn sphere(center_x: f64, center_y: f64, center_z: f64, radius: f64) -> DVector<f64> { pub fn sphere(center_x: f64, center_y: f64, center_z: f64, radius: f64) -> DVector<f64> {
@ -25,3 +33,470 @@ pub fn sphere_with_offset(dir_x: f64, dir_y: f64, dir_z: f64, off: f64, curv: f6
off * (1.0 + 0.5 * off * curv) off * (1.0 + 0.5 * off * curv)
]) ])
} }
// --- partial matrices ---
struct MatrixEntry {
index: (usize, usize),
value: f64
}
pub struct PartialMatrix(Vec<MatrixEntry>);
impl PartialMatrix {
pub fn new() -> PartialMatrix {
PartialMatrix(Vec::<MatrixEntry>::new())
}
pub fn push_sym(&mut self, row: usize, col: usize, value: f64) {
let PartialMatrix(entries) = self;
entries.push(MatrixEntry { index: (row, col), value: value });
if row != col {
entries.push(MatrixEntry { index: (col, row), value: value });
}
}
/* DEBUG */
pub fn log_to_console(&self) {
let PartialMatrix(entries) = self;
for ent in entries {
let ent_str = format!(" {} {} {}", ent.index.0, ent.index.1, ent.value);
console::log_1(&JsValue::from(ent_str.as_str()));
}
}
fn proj(&self, a: &DMatrix<f64>) -> DMatrix<f64> {
let mut result = DMatrix::<f64>::zeros(a.nrows(), a.ncols());
let PartialMatrix(entries) = self;
for ent in entries {
result[ent.index] = a[ent.index];
}
result
}
fn sub_proj(&self, rhs: &DMatrix<f64>) -> DMatrix<f64> {
let mut result = DMatrix::<f64>::zeros(rhs.nrows(), rhs.ncols());
let PartialMatrix(entries) = self;
for ent in entries {
result[ent.index] = ent.value - rhs[ent.index];
}
result
}
}
// --- descent history ---
pub struct DescentHistory {
pub config: Vec<DMatrix<f64>>,
pub scaled_loss: Vec<f64>,
pub neg_grad: Vec<DMatrix<f64>>,
pub min_eigval: Vec<f64>,
pub base_step: Vec<DMatrix<f64>>,
pub backoff_steps: Vec<i32>
}
impl DescentHistory {
fn new() -> DescentHistory {
DescentHistory {
config: Vec::<DMatrix<f64>>::new(),
scaled_loss: Vec::<f64>::new(),
neg_grad: Vec::<DMatrix<f64>>::new(),
min_eigval: Vec::<f64>::new(),
base_step: Vec::<DMatrix<f64>>::new(),
backoff_steps: Vec::<i32>::new(),
}
}
}
// --- gram matrix realization ---
// the Lorentz form
lazy_static! {
static ref Q: DMatrix<f64> = DMatrix::from_row_slice(5, 5, &[
1.0, 0.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, -2.0,
0.0, 0.0, 0.0, -2.0, 0.0
]);
}
struct SearchState {
config: DMatrix<f64>,
err_proj: DMatrix<f64>,
loss: f64
}
impl SearchState {
fn from_config(gram: &PartialMatrix, config: DMatrix<f64>) -> SearchState {
let err_proj = gram.sub_proj(&(config.tr_mul(&*Q) * &config));
let loss = err_proj.norm_squared();
SearchState {
config: config,
err_proj: err_proj,
loss: loss
}
}
}
fn basis_matrix(index: (usize, usize), nrows: usize, ncols: usize) -> DMatrix<f64> {
let mut result = DMatrix::<f64>::zeros(nrows, ncols);
result[index] = 1.0;
result
}
// use backtracking line search to find a better configuration
fn seek_better_config(
gram: &PartialMatrix,
state: &SearchState,
base_step: &DMatrix<f64>,
base_target_improvement: f64,
min_efficiency: f64,
backoff: f64,
max_backoff_steps: i32
) -> Option<(SearchState, i32)> {
let mut rate = 1.0;
for backoff_steps in 0..max_backoff_steps {
let trial_config = &state.config + rate * base_step;
let trial_state = SearchState::from_config(gram, trial_config);
let improvement = state.loss - trial_state.loss;
if improvement >= min_efficiency * rate * base_target_improvement {
return Some((trial_state, backoff_steps));
}
rate *= backoff;
}
None
}
// seek a matrix `config` for which `config' * Q * config` matches the partial
// matrix `gram`. use gradient descent starting from `guess`
pub fn realize_gram(
gram: &PartialMatrix,
guess: DMatrix<f64>,
frozen: &[(usize, usize)],
scaled_tol: f64,
min_efficiency: f64,
backoff: f64,
reg_scale: f64,
max_descent_steps: i32,
max_backoff_steps: i32
) -> (DMatrix<f64>, bool, DescentHistory) {
// start the descent history
let mut history = DescentHistory::new();
// find the dimension of the search space
let element_dim = guess.nrows();
let assembly_dim = guess.ncols();
let total_dim = element_dim * assembly_dim;
// scale the tolerance
let scale_adjustment = (gram.0.len() as f64).sqrt();
let tol = scale_adjustment * scaled_tol;
// convert the frozen indices to stacked format
let frozen_stacked: Vec<usize> = frozen.into_iter().map(
|index| index.1*element_dim + index.0
).collect();
// use Newton's method with backtracking and gradient descent backup
let mut state = SearchState::from_config(gram, guess);
for _ in 0..max_descent_steps {
// stop if the loss is tolerably low
history.config.push(state.config.clone());
history.scaled_loss.push(state.loss / scale_adjustment);
if state.loss < tol { break; }
// find the negative gradient of the loss function
let neg_grad = 4.0 * &*Q * &state.config * &state.err_proj;
let mut neg_grad_stacked = neg_grad.clone().reshape_generic(Dyn(total_dim), Const::<1>);
history.neg_grad.push(neg_grad.clone());
// find the negative Hessian of the loss function
let mut hess_cols = Vec::<DVector<f64>>::with_capacity(total_dim);
for col in 0..assembly_dim {
for row in 0..element_dim {
let index = (row, col);
let basis_mat = basis_matrix(index, element_dim, assembly_dim);
let neg_d_err =
basis_mat.tr_mul(&*Q) * &state.config
+ state.config.tr_mul(&*Q) * &basis_mat;
let neg_d_err_proj = gram.proj(&neg_d_err);
let deriv_grad = 4.0 * &*Q * (
-&basis_mat * &state.err_proj
+ &state.config * &neg_d_err_proj
);
hess_cols.push(deriv_grad.reshape_generic(Dyn(total_dim), Const::<1>));
}
}
let mut hess = DMatrix::from_columns(hess_cols.as_slice());
// regularize the Hessian
let min_eigval = hess.symmetric_eigenvalues().min();
if min_eigval <= 0.0 {
hess -= reg_scale * min_eigval * DMatrix::identity(total_dim, total_dim);
}
history.min_eigval.push(min_eigval);
// project the negative gradient and negative Hessian onto the
// orthogonal complement of the frozen subspace
let zero_col = DVector::zeros(total_dim);
let zero_row = zero_col.transpose();
for &k in &frozen_stacked {
neg_grad_stacked[k] = 0.0;
hess.set_row(k, &zero_row);
hess.set_column(k, &zero_col);
hess[(k, k)] = 1.0;
}
// compute the Newton step
/*
we need to either handle or eliminate the case where the minimum
eigenvalue of the Hessian is zero, so the regularized Hessian is
singular. right now, this causes the Cholesky decomposition to return
`None`, leading to a panic when we unrap
*/
let base_step_stacked = hess.cholesky().unwrap().solve(&neg_grad_stacked);
let base_step = base_step_stacked.reshape_generic(Dyn(element_dim), Dyn(assembly_dim));
history.base_step.push(base_step.clone());
// use backtracking line search to find a better configuration
match seek_better_config(
gram, &state, &base_step, neg_grad.dot(&base_step),
min_efficiency, backoff, max_backoff_steps
) {
Some((better_state, backoff_steps)) => {
state = better_state;
history.backoff_steps.push(backoff_steps);
},
None => return (state.config, false, history)
};
}
(state.config, state.loss < tol, history)
}
// --- tests ---
#[cfg(test)]
mod tests {
use std::{array, f64::consts::PI};
use super::*;
#[test]
fn sub_proj_test() {
let target = PartialMatrix(vec![
MatrixEntry { index: (0, 0), value: 19.0 },
MatrixEntry { index: (0, 2), value: 39.0 },
MatrixEntry { index: (1, 1), value: 59.0 },
MatrixEntry { index: (1, 2), value: 69.0 }
]);
let attempt = DMatrix::<f64>::from_row_slice(2, 3, &[
1.0, 2.0, 3.0,
4.0, 5.0, 6.0
]);
let expected_result = DMatrix::<f64>::from_row_slice(2, 3, &[
18.0, 0.0, 36.0,
0.0, 54.0, 63.0
]);
assert_eq!(target.sub_proj(&attempt), expected_result);
}
#[test]
fn zero_loss_test() {
let gram = PartialMatrix({
let mut entries = Vec::<MatrixEntry>::new();
for j in 0..3 {
for k in 0..3 {
entries.push(MatrixEntry {
index: (j, k),
value: if j == k { 1.0 } else { -1.0 }
});
}
}
entries
});
let config = {
let a: f64 = 0.75_f64.sqrt();
DMatrix::from_columns(&[
sphere(1.0, 0.0, 0.0, a),
sphere(-0.5, a, 0.0, a),
sphere(-0.5, -a, 0.0, a)
])
};
let state = SearchState::from_config(&gram, config);
assert!(state.loss.abs() < f64::EPSILON);
}
// this problem is from a sangaku by Irisawa Shintarō Hiroatsu. the article
// below includes a nice translation of the problem statement, which was
// recorded in Uchida Itsumi's book _Kokon sankan_ (_Mathematics, Past and
// Present_)
//
// "Japan's 'Wasan' Mathematical Tradition", by Abe Haruki
// https://www.nippon.com/en/japan-topics/c12801/
//
#[test]
fn irisawa_hexlet_test() {
let gram = PartialMatrix({
let mut entries = Vec::<MatrixEntry>::new();
for s in 0..9 {
// each sphere is represented by a spacelike vector
entries.push(MatrixEntry { index: (s, s), value: 1.0 });
// the circumscribing sphere is tangent to all of the other
// spheres, with matching orientation
if s > 0 {
entries.push(MatrixEntry { index: (0, s), value: 1.0 });
entries.push(MatrixEntry { index: (s, 0), value: 1.0 });
}
if s > 2 {
// each chain sphere is tangent to the "sun" and "moon"
// spheres, with opposing orientation
for n in 1..3 {
entries.push(MatrixEntry { index: (s, n), value: -1.0 });
entries.push(MatrixEntry { index: (n, s), value: -1.0 });
}
// each chain sphere is tangent to the next chain sphere,
// with opposing orientation
let s_next = 3 + (s-2) % 6;
entries.push(MatrixEntry { index: (s, s_next), value: -1.0 });
entries.push(MatrixEntry { index: (s_next, s), value: -1.0 });
}
}
entries
});
let guess = DMatrix::from_columns(
[
sphere(0.0, 0.0, 0.0, 15.0),
sphere(0.0, 0.0, -9.0, 5.0),
sphere(0.0, 0.0, 11.0, 3.0)
].into_iter().chain(
(1..=6).map(
|k| {
let ang = (k as f64) * PI/3.0;
sphere(9.0 * ang.cos(), 9.0 * ang.sin(), 0.0, 2.5)
}
)
).collect::<Vec<_>>().as_slice()
);
let frozen: [(usize, usize); 4] = array::from_fn(|k| (3, k));
const SCALED_TOL: f64 = 1.0e-12;
let (config, success, history) = realize_gram(
&gram, guess, &frozen,
SCALED_TOL, 0.5, 0.9, 1.1, 200, 110
);
let entry_tol = SCALED_TOL.sqrt();
let solution_diams = [30.0, 10.0, 6.0, 5.0, 15.0, 10.0, 3.75, 2.5, 2.0 + 8.0/11.0];
for (k, diam) in solution_diams.into_iter().enumerate() {
assert!((config[(3, k)] - 1.0 / diam).abs() < entry_tol);
}
print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
if success {
println!("Target accuracy achieved!");
} else {
println!("Failed to reach target accuracy");
}
println!("Steps: {}", history.scaled_loss.len() - 1);
println!("Loss: {}", history.scaled_loss.last().unwrap());
if success {
println!("\nChain diameters:");
println!(" {} sun (given)", 1.0 / config[(3, 3)]);
for k in 4..9 {
println!(" {} sun", 1.0 / config[(3, k)]);
}
}
println!("\nStep │ Loss\n─────┼────────────────────────────────");
for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
println!("{:<4}{}", step, scaled_loss);
}
}
// --- process inspection examples ---
// these tests are meant for human inspection, not automated use. run them
// one at a time in `--nocapture` mode and read through the results and
// optimization histories that they print out. the `run-examples` script
// will run all of them
#[test]
fn three_spheres_example() {
let gram = PartialMatrix({
let mut entries = Vec::<MatrixEntry>::new();
for j in 0..3 {
for k in 0..3 {
entries.push(MatrixEntry {
index: (j, k),
value: if j == k { 1.0 } else { -1.0 }
});
}
}
entries
});
let guess = {
let a: f64 = 0.75_f64.sqrt();
DMatrix::from_columns(&[
sphere(1.0, 0.0, 0.0, 1.0),
sphere(-0.5, a, 0.0, 1.0),
sphere(-0.5, -a, 0.0, 1.0)
])
};
println!();
let (config, success, history) = realize_gram(
&gram, guess, &[],
1.0e-12, 0.5, 0.9, 1.1, 200, 110
);
print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
if success {
println!("Target accuracy achieved!");
} else {
println!("Failed to reach target accuracy");
}
println!("Steps: {}", history.scaled_loss.len() - 1);
println!("Loss: {}", history.scaled_loss.last().unwrap());
println!("\nStep │ Loss\n─────┼────────────────────────────────");
for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
println!("{:<4}{}", step, scaled_loss);
}
}
#[test]
fn point_on_sphere_example() {
let gram = PartialMatrix({
let mut entries = Vec::<MatrixEntry>::new();
for j in 0..2 {
for k in 0..2 {
entries.push(MatrixEntry {
index: (j, k),
value: if (j, k) == (1, 1) { 1.0 } else { 0.0 }
});
}
}
entries
});
let guess = DMatrix::from_columns(&[
point(0.0, 0.0, 2.0),
sphere(0.0, 0.0, 0.0, 1.0)
]);
let frozen = [(3, 0)];
println!();
let (config, success, history) = realize_gram(
&gram, guess, &frozen,
1.0e-12, 0.5, 0.9, 1.1, 200, 110
);
print!("\nCompleted Gram matrix:{}", config.tr_mul(&*Q) * &config);
print!("Configuration:{}", config);
if success {
println!("Target accuracy achieved!");
} else {
println!("Failed to reach target accuracy");
}
println!("Steps: {}", history.scaled_loss.len() - 1);
println!("Loss: {}", history.scaled_loss.last().unwrap());
println!("\nStep │ Loss\n─────┼────────────────────────────────");
for (step, scaled_loss) in history.scaled_loss.into_iter().enumerate() {
println!("{:<4}{}", step, scaled_loss);
}
}
}

View File

@ -1,26 +1,191 @@
use itertools::Itertools; use itertools::Itertools;
use sycamore::{prelude::*, web::tags::div}; use sycamore::prelude::*;
use web_sys::{Element, KeyboardEvent, MouseEvent, wasm_bindgen::JsCast}; use web_sys::{
Event,
HtmlInputElement,
KeyboardEvent,
MouseEvent,
wasm_bindgen::JsCast
};
use crate::AppState; use crate::{AppState, assembly, assembly::Constraint};
// this component lists the elements of the assembly, showing the constraints // an editable view of the Lorentz product representing a constraint
// on each element as a collapsible sub-list. its implementation is based on #[component(inline_props)]
// Kate Morley's HTML + CSS tree views: fn LorentzProductInput(constraint: Constraint) -> View {
view! {
input(
r#type="text",
bind:value=constraint.rep_text,
on:change=move |event: Event| {
let target: HtmlInputElement = event.target().unwrap().unchecked_into();
match target.value().parse::<f64>() {
Ok(rep) => batch(|| {
constraint.rep.set(rep);
constraint.rep_valid.set(true);
}),
Err(_) => constraint.rep_valid.set(false)
};
}
)
}
}
// a list item that shows a constraint in an outline view of an element
#[component(inline_props)]
fn ConstraintOutlineItem(constraint_key: usize, element_key: usize) -> View {
let state = use_context::<AppState>();
let assembly = &state.assembly;
let constraint = assembly.constraints.with(|csts| csts[constraint_key].clone());
let other_arg = if constraint.args.0 == element_key {
constraint.args.1
} else {
constraint.args.0
};
let other_arg_label = assembly.elements.with(|elts| elts[other_arg].label.clone());
let class = constraint.rep_valid.map(
|&rep_valid| if rep_valid { "cst" } else { "cst invalid" }
);
view! {
li(class=class.get()) {
input(r#type="checkbox", bind:checked=constraint.active)
div(class="cst-label") { (other_arg_label) }
LorentzProductInput(constraint=constraint)
div(class="status")
}
}
}
// a list item that shows an element in an outline view of an assembly
#[component(inline_props)]
fn ElementOutlineItem(key: usize, element: assembly::Element) -> View {
let state = use_context::<AppState>();
let class = state.selection.map(
move |sel| if sel.contains(&key) { "selected" } else { "" }
);
let label = element.label.clone();
let rep_components = element.rep.map(
|rep| rep.iter().map(
|u| format!("{:.3}", u).replace("-", "\u{2212}")
).collect()
);
let constrained = element.constraints.map(|csts| csts.len() > 0);
let constraint_list = element.constraints.map(
|csts| csts.clone().into_iter().collect()
);
let details_node = create_node_ref();
view! {
li {
details(ref=details_node) {
summary(
class=class.get(),
on:keydown={
move |event: KeyboardEvent| {
match event.key().as_str() {
"Enter" => {
if event.shift_key() {
state.selection.update(|sel| {
if !sel.remove(&key) {
sel.insert(key);
}
});
} else {
state.selection.update(|sel| {
sel.clear();
sel.insert(key);
});
}
event.prevent_default();
},
"ArrowRight" if constrained.get() => {
let _ = details_node
.get()
.unchecked_into::<web_sys::Element>()
.set_attribute("open", "");
},
"ArrowLeft" => {
let _ = details_node
.get()
.unchecked_into::<web_sys::Element>()
.remove_attribute("open");
},
_ => ()
}
}
}
) {
div(
class="elt-switch",
on:click=|event: MouseEvent| event.stop_propagation()
)
div(
class="elt",
on:click={
move |event: MouseEvent| {
if event.shift_key() {
state.selection.update(|sel| {
if !sel.remove(&key) {
sel.insert(key);
}
});
} else {
state.selection.update(|sel| {
sel.clear();
sel.insert(key);
});
}
event.stop_propagation();
event.prevent_default();
}
}
) {
div(class="elt-label") { (label) }
div(class="elt-rep") {
Indexed(
list=rep_components,
view=|coord_str| view! {
div { (coord_str) }
}
)
}
div(class="status")
}
}
ul(class="constraints") {
Keyed(
list=constraint_list,
view=move |cst_key| view! {
ConstraintOutlineItem(
constraint_key=cst_key,
element_key=key
)
},
key=|cst_key| cst_key.clone()
)
}
}
}
}
}
// a component that lists the elements of the current assembly, showing the
// constraints on each element as a collapsible sub-list. its implementation
// is based on Kate Morley's HTML + CSS tree views:
// //
// https://iamkate.com/code/tree-views/ // https://iamkate.com/code/tree-views/
// //
#[component] #[component]
pub fn Outline() -> View { pub fn Outline() -> View {
// sort the elements alphabetically by ID let state = use_context::<AppState>();
let elements_sorted = create_memo(|| {
let state = use_context::<AppState>(); // list the elements alphabetically by ID
state.assembly.elements let element_list = state.assembly.elements.map(
.get_clone() |elts| elts
.clone()
.into_iter() .into_iter()
.sorted_by_key(|(_, elt)| elt.id.clone()) .sorted_by_key(|(_, elt)| elt.id.clone())
.collect() .collect()
}); );
view! { view! {
ul( ul(
@ -31,130 +196,11 @@ pub fn Outline() -> View {
} }
) { ) {
Keyed( Keyed(
list=elements_sorted, list=element_list,
view=|(key, elt)| { view=|(key, elt)| view! {
let state = use_context::<AppState>(); ElementOutlineItem(key=key, element=elt)
let class = create_memo({
move || {
if state.selection.with(|sel| sel.contains(&key)) {
"selected"
} else {
""
}
}
});
let label = elt.label.clone();
let rep_components = elt.rep.iter().map(|u| {
let u_coord = u.to_string().replace("-", "\u{2212}");
View::from(div().children(u_coord))
}).collect::<Vec<_>>();
let constrained = elt.constraints.len() > 0;
let details_node = create_node_ref();
view! {
/* [TO DO] switch to integer-valued parameters whenever
that becomes possible again */
li {
details(ref=details_node) {
summary(
class=class.get(),
on:keydown={
move |event: KeyboardEvent| {
match event.key().as_str() {
"Enter" => {
if event.shift_key() {
state.selection.update(|sel| {
if !sel.remove(&key) {
sel.insert(key);
}
});
} else {
state.selection.update(|sel| {
sel.clear();
sel.insert(key);
});
}
event.prevent_default();
},
"ArrowRight" if constrained => {
let _ = details_node
.get()
.unchecked_into::<Element>()
.set_attribute("open", "");
},
"ArrowLeft" => {
let _ = details_node
.get()
.unchecked_into::<Element>()
.remove_attribute("open");
},
_ => ()
}
}
}
) {
div(
class="elt-switch",
on:click=|event: MouseEvent| event.stop_propagation()
)
div(
class="elt",
on:click={
move |event: MouseEvent| {
if event.shift_key() {
state.selection.update(|sel| {
if !sel.remove(&key) {
sel.insert(key);
}
});
} else {
state.selection.update(|sel| {
sel.clear();
sel.insert(key);
});
}
event.stop_propagation();
event.prevent_default();
}
}
) {
div(class="elt-label") { (label) }
div(class="elt-rep") { (rep_components) }
}
}
ul(class="constraints") {
Keyed(
list=elt.constraints.into_iter().collect::<Vec<_>>(),
view=move |c_key: usize| {
let c_state = use_context::<AppState>();
let assembly = &c_state.assembly;
let cst = assembly.constraints.with(|csts| csts[c_key].clone());
let other_arg = if cst.args.0 == key {
cst.args.1
} else {
cst.args.0
};
let other_arg_label = assembly.elements.with(|elts| elts[other_arg].label.clone());
view! {
li(class="cst") {
input(r#type="checkbox", bind:checked=cst.active)
div(class="cst-label") { (other_arg_label) }
div(class="cst-rep") { (cst.rep) }
}
}
},
key=|c_key| c_key.clone()
)
}
}
}
}
}, },
key=|(key, elt)| ( key=|(key, _)| key.clone()
key.clone(),
elt.id.clone(),
elt.label.clone(),
elt.constraints.clone()
)
) )
} }
} }

View File

@ -8,7 +8,8 @@ using Optim
export export
rand_on_shell, Q, DescentHistory, rand_on_shell, Q, DescentHistory,
realize_gram_gradient, realize_gram_newton, realize_gram_optim, realize_gram realize_gram_gradient, realize_gram_newton, realize_gram_optim,
realize_gram_alt_proj, realize_gram
# === guessing === # === guessing ===
@ -143,7 +144,7 @@ function realize_gram_gradient(
break break
end end
# find negative gradient of loss function # find the negative gradient of the loss function
neg_grad = 4*Q*L*Δ_proj neg_grad = 4*Q*L*Δ_proj
slope = norm(neg_grad) slope = norm(neg_grad)
dir = neg_grad / slope dir = neg_grad / slope
@ -232,7 +233,7 @@ function realize_gram_newton(
break break
end end
# find the negative gradient of loss function # find the negative gradient of the loss function
neg_grad = 4*Q*L*Δ_proj neg_grad = 4*Q*L*Δ_proj
# find the negative Hessian of the loss function # find the negative Hessian of the loss function
@ -313,6 +314,129 @@ function realize_gram_optim(
) )
end end
# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use gradient descent starting from `guess`, with an
# alternate technique for finding the projected base step from the unprojected
# Hessian
function realize_gram_alt_proj(
gram::SparseMatrixCSC{T, <:Any},
guess::Matrix{T},
frozen = CartesianIndex[];
scaled_tol = 1e-30,
min_efficiency = 0.5,
backoff = 0.9,
reg_scale = 1.1,
max_descent_steps = 200,
max_backoff_steps = 110
) where T <: Number
# start history
history = DescentHistory{T}()
# find the dimension of the search space
dims = size(guess)
element_dim, construction_dim = dims
total_dim = element_dim * construction_dim
# list the constrained entries of the gram matrix
J, K, _ = findnz(gram)
constrained = zip(J, K)
# scale the tolerance
scale_adjustment = sqrt(T(length(constrained)))
tol = scale_adjustment * scaled_tol
# convert the frozen indices to stacked format
frozen_stacked = [(index[2]-1)*element_dim + index[1] for index in frozen]
# initialize search state
L = copy(guess)
Δ_proj = proj_diff(gram, L'*Q*L)
loss = dot(Δ_proj, Δ_proj)
# use Newton's method with backtracking and gradient descent backup
for step in 1:max_descent_steps
# stop if the loss is tolerably low
if loss < tol
break
end
# find the negative gradient of the loss function
neg_grad = 4*Q*L*Δ_proj
# find the negative Hessian of the loss function
hess = Matrix{T}(undef, total_dim, total_dim)
indices = [(j, k) for k in 1:construction_dim for j in 1:element_dim]
for (j, k) in indices
basis_mat = basis_matrix(T, j, k, dims)
neg_dΔ = basis_mat'*Q*L + L'*Q*basis_mat
neg_dΔ_proj = proj_to_entries(neg_dΔ, constrained)
deriv_grad = 4*Q*(-basis_mat*Δ_proj + L*neg_dΔ_proj)
hess[:, (k-1)*element_dim + j] = reshape(deriv_grad, total_dim)
end
hess_sym = Hermitian(hess)
push!(history.hess, hess_sym)
# regularize the Hessian
min_eigval = minimum(eigvals(hess_sym))
push!(history.positive, min_eigval > 0)
if min_eigval <= 0
hess -= reg_scale * min_eigval * I
end
# compute the Newton step
neg_grad_stacked = reshape(neg_grad, total_dim)
for k in frozen_stacked
neg_grad_stacked[k] = 0
hess[k, :] .= 0
hess[:, k] .= 0
hess[k, k] = 1
end
base_step_stacked = Hermitian(hess) \ neg_grad_stacked
base_step = reshape(base_step_stacked, dims)
push!(history.base_step, base_step)
# store the current position, loss, and slope
L_last = L
loss_last = loss
push!(history.scaled_loss, loss / scale_adjustment)
push!(history.neg_grad, neg_grad)
push!(history.slope, norm(neg_grad))
# find a good step size using backtracking line search
push!(history.stepsize, 0)
push!(history.backoff_steps, max_backoff_steps)
empty!(history.last_line_L)
empty!(history.last_line_loss)
rate = one(T)
step_success = false
base_target_improvement = dot(neg_grad, base_step)
for backoff_steps in 0:max_backoff_steps
history.stepsize[end] = rate
L = L_last + rate * base_step
Δ_proj = proj_diff(gram, L'*Q*L)
loss = dot(Δ_proj, Δ_proj)
improvement = loss_last - loss
push!(history.last_line_L, L)
push!(history.last_line_loss, loss / scale_adjustment)
if improvement >= min_efficiency * rate * base_target_improvement
history.backoff_steps[end] = backoff_steps
step_success = true
break
end
rate *= backoff
end
# if we've hit a wall, quit
if !step_success
return L_last, false, history
end
end
# return the factorization and its history
push!(history.scaled_loss, loss / scale_adjustment)
L, loss < tol, history
end
# seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every # seek a matrix `L` for which `L'QL` matches the sparse matrix `gram` at every
# explicit entry of `gram`. use gradient descent starting from `guess` # explicit entry of `gram`. use gradient descent starting from `guess`
function realize_gram( function realize_gram(
@ -321,7 +445,6 @@ function realize_gram(
frozen = nothing; frozen = nothing;
scaled_tol = 1e-30, scaled_tol = 1e-30,
min_efficiency = 0.5, min_efficiency = 0.5,
init_rate = 1.0,
backoff = 0.9, backoff = 0.9,
reg_scale = 1.1, reg_scale = 1.1,
max_descent_steps = 200, max_descent_steps = 200,
@ -352,20 +475,19 @@ function realize_gram(
unfrozen_stacked = reshape(is_unfrozen, total_dim) unfrozen_stacked = reshape(is_unfrozen, total_dim)
end end
# initialize variables # initialize search state
grad_rate = init_rate
L = copy(guess) L = copy(guess)
# use Newton's method with backtracking and gradient descent backup
Δ_proj = proj_diff(gram, L'*Q*L) Δ_proj = proj_diff(gram, L'*Q*L)
loss = dot(Δ_proj, Δ_proj) loss = dot(Δ_proj, Δ_proj)
# use Newton's method with backtracking and gradient descent backup
for step in 1:max_descent_steps for step in 1:max_descent_steps
# stop if the loss is tolerably low # stop if the loss is tolerably low
if loss < tol if loss < tol
break break
end end
# find the negative gradient of loss function # find the negative gradient of the loss function
neg_grad = 4*Q*L*Δ_proj neg_grad = 4*Q*L*Δ_proj
# find the negative Hessian of the loss function # find the negative Hessian of the loss function
@ -420,6 +542,7 @@ function realize_gram(
empty!(history.last_line_loss) empty!(history.last_line_loss)
rate = one(T) rate = one(T)
step_success = false step_success = false
base_target_improvement = dot(neg_grad, base_step)
for backoff_steps in 0:max_backoff_steps for backoff_steps in 0:max_backoff_steps
history.stepsize[end] = rate history.stepsize[end] = rate
L = L_last + rate * base_step L = L_last + rate * base_step
@ -428,7 +551,7 @@ function realize_gram(
improvement = loss_last - loss improvement = loss_last - loss
push!(history.last_line_L, L) push!(history.last_line_L, L)
push!(history.last_line_loss, loss / scale_adjustment) push!(history.last_line_loss, loss / scale_adjustment)
if improvement >= min_efficiency * rate * dot(neg_grad, base_step) if improvement >= min_efficiency * rate * base_target_improvement
history.backoff_steps[end] = backoff_steps history.backoff_steps[end] = backoff_steps
step_success = true step_success = true
break break

View File

@ -75,3 +75,12 @@ if success
println(" ", 1 / L[4,k], " sun") println(" ", 1 / L[4,k], " sun")
end end
end end
# test an alternate technique for finding the projected base step from the
# unprojected Hessian
L_alt, success_alt, history_alt = Engine.realize_gram_alt_proj(gram, guess, frozen)
completed_gram_alt = L_alt'*Engine.Q*L_alt
println("\nDifference in result using alternate projection:\n")
display(completed_gram_alt - completed_gram)
println("\nDifference in steps: ", size(history_alt.scaled_loss, 1) - size(history.scaled_loss, 1))
println("Difference in loss: ", history_alt.scaled_loss[end] - history.scaled_loss[end], "\n")

View File

@ -65,3 +65,12 @@ else
end end
println("Steps: ", size(history.scaled_loss, 1)) println("Steps: ", size(history.scaled_loss, 1))
println("Loss: ", history.scaled_loss[end], "\n") println("Loss: ", history.scaled_loss[end], "\n")
# test an alternate technique for finding the projected base step from the
# unprojected Hessian
L_alt, success_alt, history_alt = Engine.realize_gram_alt_proj(gram, guess, frozen)
completed_gram_alt = L_alt'*Engine.Q*L_alt
println("\nDifference in result using alternate projection:\n")
display(completed_gram_alt - completed_gram)
println("\nDifference in steps: ", size(history_alt.scaled_loss, 1) - size(history.scaled_loss, 1))
println("Difference in loss: ", history_alt.scaled_loss[end] - history.scaled_loss[end], "\n")

View File

@ -94,3 +94,12 @@ if success
radius_ratio = dot(infty, Engine.Q * L[:,5]) / dot(infty, Engine.Q * L[:,6]) radius_ratio = dot(infty, Engine.Q * L[:,5]) / dot(infty, Engine.Q * L[:,6])
println("\nCircumradius / inradius: ", radius_ratio) println("\nCircumradius / inradius: ", radius_ratio)
end end
# test an alternate technique for finding the projected base step from the
# unprojected Hessian
L_alt, success_alt, history_alt = Engine.realize_gram_alt_proj(gram, guess, frozen)
completed_gram_alt = L_alt'*Engine.Q*L_alt
println("\nDifference in result using alternate projection:\n")
display(completed_gram_alt - completed_gram)
println("\nDifference in steps: ", size(history_alt.scaled_loss, 1) - size(history.scaled_loss, 1))
println("Difference in loss: ", history_alt.scaled_loss[end] - history.scaled_loss[end], "\n")