Visualize neighborhoods of global minima
This commit is contained in:
parent
77bc124170
commit
f84d475580
99
engine-proto/gram-test/basin-shapes.jl
Normal file
99
engine-proto/gram-test/basin-shapes.jl
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
include("Engine.jl")
|
||||||
|
|
||||||
|
using LinearAlgebra
|
||||||
|
using SparseArrays
|
||||||
|
|
||||||
|
function sphere_in_tetrahedron_shape()
|
||||||
|
# initialize the partial gram matrix for a sphere inscribed in a regular
|
||||||
|
# tetrahedron
|
||||||
|
J = Int64[]
|
||||||
|
K = Int64[]
|
||||||
|
values = BigFloat[]
|
||||||
|
for j in 1:5
|
||||||
|
for k in 1:5
|
||||||
|
push!(J, j)
|
||||||
|
push!(K, k)
|
||||||
|
if j == k
|
||||||
|
push!(values, 1)
|
||||||
|
elseif (j <= 4 && k <= 4)
|
||||||
|
push!(values, -1/BigFloat(3))
|
||||||
|
else
|
||||||
|
push!(values, -1)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
gram = sparse(J, K, values)
|
||||||
|
|
||||||
|
# plot loss along a slice
|
||||||
|
loss_lin = []
|
||||||
|
loss_sq = []
|
||||||
|
mesh = range(0.9, 1.1, 101)
|
||||||
|
for t in mesh
|
||||||
|
L = hcat(
|
||||||
|
Engine.plane(normalize(BigFloat[ 1, 1, 1]), BigFloat(1)),
|
||||||
|
Engine.plane(normalize(BigFloat[ 1, -1, -1]), BigFloat(1)),
|
||||||
|
Engine.plane(normalize(BigFloat[-1, 1, -1]), BigFloat(1)),
|
||||||
|
Engine.plane(normalize(BigFloat[-1, -1, 1]), BigFloat(1)),
|
||||||
|
Engine.sphere(BigFloat[0, 0, 0], BigFloat(t))
|
||||||
|
)
|
||||||
|
Δ_proj = Engine.proj_diff(gram, L'*Engine.Q*L)
|
||||||
|
push!(loss_lin, norm(Δ_proj))
|
||||||
|
push!(loss_sq, dot(Δ_proj, Δ_proj))
|
||||||
|
end
|
||||||
|
mesh, loss_lin, loss_sq
|
||||||
|
end
|
||||||
|
|
||||||
|
function circles_in_triangle_shape()
|
||||||
|
# initialize the partial gram matrix for a sphere inscribed in a regular
|
||||||
|
# tetrahedron
|
||||||
|
J = Int64[]
|
||||||
|
K = Int64[]
|
||||||
|
values = BigFloat[]
|
||||||
|
for j in 1:8
|
||||||
|
for k in 1:8
|
||||||
|
filled = false
|
||||||
|
if j == k
|
||||||
|
push!(values, 1)
|
||||||
|
filled = true
|
||||||
|
elseif (j == 1 || k == 1)
|
||||||
|
push!(values, 0)
|
||||||
|
filled = true
|
||||||
|
elseif (j == 2 || k == 2)
|
||||||
|
push!(values, -1)
|
||||||
|
filled = true
|
||||||
|
end
|
||||||
|
#=elseif (j <= 5 && j != 2 && k == 9 || k == 9 && k <= 5 && k != 2)
|
||||||
|
push!(values, 0)
|
||||||
|
filled = true
|
||||||
|
end=#
|
||||||
|
if filled
|
||||||
|
push!(J, j)
|
||||||
|
push!(K, k)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
append!(J, [6, 4, 6, 5, 7, 5, 7, 3, 8, 3, 8, 4])
|
||||||
|
append!(K, [4, 6, 5, 6, 5, 7, 3, 7, 3, 8, 4, 8])
|
||||||
|
append!(values, fill(-1, 12))
|
||||||
|
|
||||||
|
# plot loss along a slice
|
||||||
|
loss_lin = []
|
||||||
|
loss_sq = []
|
||||||
|
mesh = range(0.99, 1.01, 101)
|
||||||
|
for t in mesh
|
||||||
|
L = hcat(
|
||||||
|
Engine.plane(BigFloat[0, 0, 1], BigFloat(0)),
|
||||||
|
Engine.sphere(BigFloat[0, 0, 0], BigFloat(t)),
|
||||||
|
Engine.plane(BigFloat[1, 0, 0], BigFloat(1)),
|
||||||
|
Engine.plane(BigFloat[cos(2pi/3), sin(2pi/3), 0], BigFloat(1)),
|
||||||
|
Engine.plane(BigFloat[cos(-2pi/3), sin(-2pi/3), 0], BigFloat(1)),
|
||||||
|
Engine.sphere(4//3*BigFloat[-1, 0, 0], BigFloat(1//3)),
|
||||||
|
Engine.sphere(4//3*BigFloat[cos(-pi/3), sin(-pi/3), 0], BigFloat(1//3)),
|
||||||
|
Engine.sphere(4//3*BigFloat[cos(pi/3), sin(pi/3), 0], BigFloat(1//3))
|
||||||
|
)
|
||||||
|
Δ_proj = Engine.proj_diff(gram, L'*Engine.Q*L)
|
||||||
|
push!(loss_lin, norm(Δ_proj))
|
||||||
|
push!(loss_sq, dot(Δ_proj, Δ_proj))
|
||||||
|
end
|
||||||
|
mesh, loss_lin, loss_sq
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user