From 3bd331d2907ac83057b4493b4b7e4ae08945aad5 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 5 Sep 2025 11:48:35 -0700 Subject: [PATCH 1/7] fix: change parameter_values(::GraphSystemParameters) --- src/symbolic_indexing.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 46dcfb8..5e567d9 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -95,9 +95,14 @@ function Base.setindex!(u::AbstractArray, val, idx::StateIndex) setindex!(u, val, idx.idx) end +function Base.checkindex(::Type{Bool}, inds::AbstractUnitRange, i::StateIndex) + checkindex(Bool, inds, i.idx) +end + function Base.getindex(u::GraphSystemParameters, p::ParamIndex) u.params_partitioned[p] end + function Base.getindex(u::Tuple, (;tup_index, v_index, prop)::ParamIndex) getproperty(u[tup_index][v_index], prop) end @@ -155,7 +160,7 @@ function SymbolicIndexingInterface.parameter_index(g::PartitionedGraphSystem, sy end function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters) - p + p.params_partitioned end function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i::ParamIndex) p.params_partitioned[i] From af821a33a0609bcbe5f2a238ef2caf240d55bf32 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 5 Sep 2025 12:05:11 -0700 Subject: [PATCH 2/7] remove checkindex --- src/symbolic_indexing.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 5e567d9..2fb365f 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -95,14 +95,9 @@ function Base.setindex!(u::AbstractArray, val, idx::StateIndex) setindex!(u, val, idx.idx) end -function Base.checkindex(::Type{Bool}, inds::AbstractUnitRange, i::StateIndex) - checkindex(Bool, inds, i.idx) -end - function Base.getindex(u::GraphSystemParameters, p::ParamIndex) u.params_partitioned[p] end - function Base.getindex(u::Tuple, (;tup_index, v_index, prop)::ParamIndex) getproperty(u[tup_index][v_index], prop) end From 07dc66c7c9b8c4c132730aad34b69b98132f6d59 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 5 Sep 2025 12:42:55 -0700 Subject: [PATCH 3/7] fix: revert and add set_parameter! dispatch --- src/symbolic_indexing.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 2fb365f..5ca132f 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -155,7 +155,7 @@ function SymbolicIndexingInterface.parameter_index(g::PartitionedGraphSystem, sy end function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters) - p.params_partitioned + p end function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i::ParamIndex) p.params_partitioned[i] @@ -164,6 +164,10 @@ function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i: p.connection_matrices[i] end +function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, idx::ParamIndex) + set_param!!(p, nothing, idx, val) +end + function SymbolicIndexingInterface.parameter_symbols(g::PartitionedGraphSystem) collect(Iterators.flatten((keys(g.param_namemap), keys(g.connection_namemap)))) end From 93108e0e8e8efc5cb28d02f7854f6f4fb623c883 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 10 Sep 2025 18:13:55 -0700 Subject: [PATCH 4/7] fix: setp should error on type unstable changes --- src/symbolic_indexing.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 5ca132f..b707d76 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -165,7 +165,11 @@ function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i: end function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, idx::ParamIndex) - set_param!!(p, nothing, idx, val) + (; params_partitioned) = p + params = params_partitioned[tup_index][v_index] + params_new = set_param_prop(params, prop, val; allow_typechange=false) + params_partitioned[tup_index][v_index] = params_new + p end function SymbolicIndexingInterface.parameter_symbols(g::PartitionedGraphSystem) From 345e117710c2dd6c1ca7da585b0fdd17467cc991 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 10 Sep 2025 19:16:25 -0700 Subject: [PATCH 5/7] test: test error on setp --- src/subsystems.jl | 2 +- src/symbolic_indexing.jl | 1 + test/symbolic_indexing.jl | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/subsystems.jl b/src/subsystems.jl index caad57e..37d03ba 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -17,7 +17,7 @@ function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple set_param_prop(s, patch; allow_typechange=false) end function set_param_prop(s::SubsystemParams{T}, key, val; allow_typechange=false) where {T} - set_param_prop(s, NamedTuple{(key,)}(val); allow_typechange) + set_param_prop(s, NamedTuple{(key,)}((val,)); allow_typechange) end function set_param_prop(s::SubsystemParams{T}, patch; allow_typechange=false) where {T} props = NamedTuple(s) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index b707d76..2d9719c 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -165,6 +165,7 @@ function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i: end function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, idx::ParamIndex) + (; tup_index, v_index, prop) = idx (; params_partitioned) = p params = params_partitioned[tup_index][v_index] params_new = set_param_prop(params, prop, val; allow_typechange=false) diff --git a/test/symbolic_indexing.jl b/test/symbolic_indexing.jl index adc9d34..9e8a56d 100644 --- a/test/symbolic_indexing.jl +++ b/test/symbolic_indexing.jl @@ -1,10 +1,26 @@ include("particle_osc_example.jl") using SymbolicIndexingInterface -@testset "Symbolic Indexing of Vectors" begin +@testset "Symbolic Indexing of Vectors of observables" begin sol = solve_particle_osc(x1=1.0, x2=-1.0) a = getsym(sol, :particle1₊a)(sol)[end] ω = getsym(sol, :osc₊ω₀)(sol)[end] @test getsym(sol, [:osc₊ω₀, :particle1₊a])(sol)[end] == [ω, a] end + +@testset "setp and getp" begin + prob = particle_osc_prob(; x1 = 1.0, x2 = 0.0) + + # Test setp works + setp(prob, :particle1₊m)(prob, 2.0) + @test getp(prob, :particle1₊m)(prob) == 2 + + # Error on type-unstable change + @test_throws ErrorException setp(prob, :particle1₊m)(prob, ones(3)) + + # Remake + prob = remake(prob, p = [:particle1₊m => 2 + 3im, :particle2₊m => 3 + 2im]) + setp(prob, :particle1₊m)(prob, 3 + 3.0im) + @test getp(prob, :particle1₊m)(prob) == 3 + 3im +end From 813338c60dbf430fb5f8e281a1f6a7ee943dfd6c Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 11 Sep 2025 13:36:01 +0200 Subject: [PATCH 6/7] support and test changes to connection indices. Add failing tests for promotion --- src/graph_system.jl | 4 ++-- src/symbolic_indexing.jl | 29 +++++++++++++++++++++-------- test/particle_osc_example.jl | 2 +- test/symbolic_indexing.jl | 15 +++++++++++++++ 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/graph_system.jl b/src/graph_system.jl index 1178063..7c0b1b3 100644 --- a/src/graph_system.jl +++ b/src/graph_system.jl @@ -229,7 +229,7 @@ function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nod push!(conns, conn) for (prop, name) ∈ pairs(connection_property_namemap(conn, get_name(nodekl), get_name(nodeij))) - connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, prop) + connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) end for t ∈ event_times(conn) @@ -281,7 +281,7 @@ end name_kl = names_partitioned[k][l] name_ij = names_partitioned[i][j] for (prop, name) ∈ pairs(connection_property_namemap(conn, name_kl, name_ij)) - connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, prop) + connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) end end end diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 2d9719c..f31fc16 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -18,13 +18,13 @@ struct CompuIndex prop::Symbol requires_inputs::Bool end - struct ConnectionIndex nc::Int i_src::Int i_dst::Int j_src::Int j_dst::Int + connection_key::Symbol prop::Symbol end @@ -164,7 +164,7 @@ function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i: p.connection_matrices[i] end -function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, idx::ParamIndex) +function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, idx::ParamIndex) (; tup_index, v_index, prop) = idx (; params_partitioned) = p params = params_partitioned[tup_index][v_index] @@ -173,6 +173,16 @@ function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, p end +function SymbolicIndexingInterface.set_parameter!(buffer::GraphSystemParameters, value, conn_index::ConnectionIndex) + (;connection_matrices, connection_namemap) = buffer + (; nc, i_src, i_dst, j_src, j_dst, connection_key, prop) = conn_index + conn_old = connection_matrices[nc][i_src, i_dst][j_src, j_dst] + conn_new = setproperties(conn_old, NamedTuple{(prop,)}(value)) + connection_matrices[nc][i_src, i_dst][j_src, j_dst] = conn_new + buffer +end + + function SymbolicIndexingInterface.parameter_symbols(g::PartitionedGraphSystem) collect(Iterators.flatten((keys(g.param_namemap), keys(g.connection_namemap)))) end @@ -262,9 +272,9 @@ function set_params!!(buffer::GraphSystemParameters, param_map) (; param_namemap, connection_namemap) = buffer for (key, val) ∈ param_map if haskey(param_namemap, key) - buffer = set_param!!(buffer, key, param_namemap[key], val) + buffer = set_param!!(buffer, param_namemap[key], val) elseif haskey(connection_namemap, key) - buffer = set_param!!(buffer, key, connection_namemap[key], val) + buffer = set_param!!(buffer, connection_namemap[key], val) else error("Key $key does not correspond to a known parameter. ") end @@ -272,7 +282,10 @@ function set_params!!(buffer::GraphSystemParameters, param_map) buffer end -function set_param!!(buffer::GraphSystemParameters, key, (; tup_index, v_index, prop)::ParamIndex, val) + +# This is a possibly-out-of-place variant of set_parameter! that is meant to be used by `remake` where +# types are allowed to be widened. +function set_param!!(buffer::GraphSystemParameters, (; tup_index, v_index, prop)::ParamIndex, val) (; params_partitioned) = buffer params = params_partitioned[tup_index][v_index] params_new = set_param_prop(params, prop, val; allow_typechange=true) @@ -297,9 +310,9 @@ function re_eltype_params(params_partitioned) end end -function set_param!!(buffer::GraphSystemParameters, key, conn_index::ConnectionIndex, value) +function set_param!!(buffer::GraphSystemParameters, conn_index::ConnectionIndex, value) (;connection_matrices, connection_namemap) = buffer - (; nc, i_src, i_dst, j_src, j_dst, prop) = conn_index + (; nc, i_src, i_dst, j_src, j_dst, connection_key, prop) = conn_index conn_old = connection_matrices[nc][i_src, i_dst][j_src, j_dst] conn_new = setproperties(conn_old, NamedTuple{(prop,)}(value)) CR_new = typeof(conn_new) @@ -326,7 +339,7 @@ function set_param!!(buffer::GraphSystemParameters, key, conn_index::ConnectionI # Update the position in the namemap let conn_index_new = @set conn_index.nc = nc_new - connection_namemap[key] = conn_index_new # This is important so we don't lose track of where the parameter moved to! + connection_namemap[connection_key] = conn_index_new # This is important so we don't lose track of where the parameter moved to! end # Delete the old element! diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 2ede123..6c09bde 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -135,7 +135,7 @@ function particle_osc_prob(;x1, x2, m=3.0, mp1=1.0, kc_p1_p2=1, tspan = (0.0, 10 [:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0 ], tspan, - (:osc₊m => m, :particle1₊m =>mp1, :fac_coulomb_particle1_particle2 => kc_p1_p2) + (:osc₊m => m, :particle1₊m => mp1, :fac_coulomb_particle1_particle2 => kc_p1_p2) ) end function solve_particle_osc(;reltol=nothing, saveat=nothing, kwargs...) diff --git a/test/symbolic_indexing.jl b/test/symbolic_indexing.jl index 9e8a56d..05cda1e 100644 --- a/test/symbolic_indexing.jl +++ b/test/symbolic_indexing.jl @@ -16,6 +16,21 @@ end setp(prob, :particle1₊m)(prob, 2.0) @test getp(prob, :particle1₊m)(prob) == 2 + # Test type promotion and conversion + @test_broken begin + (prob, :particle1₊m)(prob, 20) + @test getp(prob, :particle1₊m)(prob) === 20.0 + end + # Test on connections as well + setp(prob, :fac_coulomb_particle1_particle2)(prob, 100) + @test getp(prob, :fac_coulomb_particle1_particle2)(prob) == 100 + + # Test type promotion and conversion on connections + @test_broken begin + setp(prob, :fac_coulomb_particle1_particle2)(prob, 100.0) + getp(prob, :fac_coulomb_particle1_particle2)(prob) == 100 + end + # Error on type-unstable change @test_throws ErrorException setp(prob, :particle1₊m)(prob, ones(3)) From acb5183fe8d4e767dc1ff637e4a2310a63c63f45 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 11 Sep 2025 13:39:26 +0200 Subject: [PATCH 7/7] bump version --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8990f70..06b28b1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" -version = "0.4.7" +version = "0.4.8" [workspace] projects = ["test", "scrap"] @@ -22,6 +22,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] MTKExt = ["Symbolics", "ModelingToolkit"] + [compat] Accessors = "0.1" ConstructionBase = "1.5"