From f84d475580391e5ef9b5ef2d0eca02af9c121087 Mon Sep 17 00:00:00 2001 From: Aaron Fenyes Date: Tue, 9 Jul 2024 14:01:30 -0700 Subject: [PATCH] Visualize neighborhoods of global minima --- engine-proto/gram-test/basin-shapes.jl | 99 ++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 engine-proto/gram-test/basin-shapes.jl diff --git a/engine-proto/gram-test/basin-shapes.jl b/engine-proto/gram-test/basin-shapes.jl new file mode 100644 index 0000000..5c03c01 --- /dev/null +++ b/engine-proto/gram-test/basin-shapes.jl @@ -0,0 +1,99 @@ +include("Engine.jl") + +using LinearAlgebra +using SparseArrays + +function sphere_in_tetrahedron_shape() + # initialize the partial gram matrix for a sphere inscribed in a regular + # tetrahedron + J = Int64[] + K = Int64[] + values = BigFloat[] + for j in 1:5 + for k in 1:5 + push!(J, j) + push!(K, k) + if j == k + push!(values, 1) + elseif (j <= 4 && k <= 4) + push!(values, -1/BigFloat(3)) + else + push!(values, -1) + end + end + end + gram = sparse(J, K, values) + + # plot loss along a slice + loss_lin = [] + loss_sq = [] + mesh = range(0.9, 1.1, 101) + for t in mesh + L = hcat( + Engine.plane(normalize(BigFloat[ 1, 1, 1]), BigFloat(1)), + Engine.plane(normalize(BigFloat[ 1, -1, -1]), BigFloat(1)), + Engine.plane(normalize(BigFloat[-1, 1, -1]), BigFloat(1)), + Engine.plane(normalize(BigFloat[-1, -1, 1]), BigFloat(1)), + Engine.sphere(BigFloat[0, 0, 0], BigFloat(t)) + ) + Δ_proj = Engine.proj_diff(gram, L'*Engine.Q*L) + push!(loss_lin, norm(Δ_proj)) + push!(loss_sq, dot(Δ_proj, Δ_proj)) + end + mesh, loss_lin, loss_sq +end + +function circles_in_triangle_shape() + # initialize the partial gram matrix for a sphere inscribed in a regular + # tetrahedron + J = Int64[] + K = Int64[] + values = BigFloat[] + for j in 1:8 + for k in 1:8 + filled = false + if j == k + push!(values, 1) + filled = true + elseif (j == 1 || k == 1) + push!(values, 0) + filled = true + elseif (j == 2 || k == 2) + push!(values, -1) + filled = true + end + #=elseif (j <= 5 && j != 2 && k == 9 || k == 9 && k <= 5 && k != 2) + push!(values, 0) + filled = true + end=# + if filled + push!(J, j) + push!(K, k) + end + end + end + append!(J, [6, 4, 6, 5, 7, 5, 7, 3, 8, 3, 8, 4]) + append!(K, [4, 6, 5, 6, 5, 7, 3, 7, 3, 8, 4, 8]) + append!(values, fill(-1, 12)) + + # plot loss along a slice + loss_lin = [] + loss_sq = [] + mesh = range(0.99, 1.01, 101) + for t in mesh + L = hcat( + Engine.plane(BigFloat[0, 0, 1], BigFloat(0)), + Engine.sphere(BigFloat[0, 0, 0], BigFloat(t)), + Engine.plane(BigFloat[1, 0, 0], BigFloat(1)), + Engine.plane(BigFloat[cos(2pi/3), sin(2pi/3), 0], BigFloat(1)), + Engine.plane(BigFloat[cos(-2pi/3), sin(-2pi/3), 0], BigFloat(1)), + Engine.sphere(4//3*BigFloat[-1, 0, 0], BigFloat(1//3)), + Engine.sphere(4//3*BigFloat[cos(-pi/3), sin(-pi/3), 0], BigFloat(1//3)), + Engine.sphere(4//3*BigFloat[cos(pi/3), sin(pi/3), 0], BigFloat(1//3)) + ) + Δ_proj = Engine.proj_diff(gram, L'*Engine.Q*L) + push!(loss_lin, norm(Δ_proj)) + push!(loss_sq, dot(Δ_proj, Δ_proj)) + end + mesh, loss_lin, loss_sq +end \ No newline at end of file