diff --git a/Project.toml b/Project.toml index 8990f70..06b28b1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" -version = "0.4.7" +version = "0.4.8" [workspace] projects = ["test", "scrap"] @@ -22,6 +22,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] MTKExt = ["Symbolics", "ModelingToolkit"] + [compat] Accessors = "0.1" ConstructionBase = "1.5" diff --git a/src/graph_system.jl b/src/graph_system.jl index 1178063..7c0b1b3 100644 --- a/src/graph_system.jl +++ b/src/graph_system.jl @@ -229,7 +229,7 @@ function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nod push!(conns, conn) for (prop, name) ∈ pairs(connection_property_namemap(conn, get_name(nodekl), get_name(nodeij))) - connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, prop) + connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) end for t ∈ event_times(conn) @@ -281,7 +281,7 @@ end name_kl = names_partitioned[k][l] name_ij = names_partitioned[i][j] for (prop, name) ∈ pairs(connection_property_namemap(conn, name_kl, name_ij)) - connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, prop) + connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) end end end diff --git a/src/subsystems.jl b/src/subsystems.jl index caad57e..37d03ba 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -17,7 +17,7 @@ function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple set_param_prop(s, patch; allow_typechange=false) end function set_param_prop(s::SubsystemParams{T}, key, val; allow_typechange=false) where {T} - set_param_prop(s, NamedTuple{(key,)}(val); allow_typechange) + set_param_prop(s, NamedTuple{(key,)}((val,)); allow_typechange) end function set_param_prop(s::SubsystemParams{T}, patch; allow_typechange=false) where {T} props = NamedTuple(s) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 46dcfb8..f31fc16 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -18,13 +18,13 @@ struct CompuIndex prop::Symbol requires_inputs::Bool end - struct ConnectionIndex nc::Int i_src::Int i_dst::Int j_src::Int j_dst::Int + connection_key::Symbol prop::Symbol end @@ -164,6 +164,25 @@ function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i: p.connection_matrices[i] end +function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, idx::ParamIndex) + (; tup_index, v_index, prop) = idx + (; params_partitioned) = p + params = params_partitioned[tup_index][v_index] + params_new = set_param_prop(params, prop, val; allow_typechange=false) + params_partitioned[tup_index][v_index] = params_new + p +end + +function SymbolicIndexingInterface.set_parameter!(buffer::GraphSystemParameters, value, conn_index::ConnectionIndex) + (;connection_matrices, connection_namemap) = buffer + (; nc, i_src, i_dst, j_src, j_dst, connection_key, prop) = conn_index + conn_old = connection_matrices[nc][i_src, i_dst][j_src, j_dst] + conn_new = setproperties(conn_old, NamedTuple{(prop,)}(value)) + connection_matrices[nc][i_src, i_dst][j_src, j_dst] = conn_new + buffer +end + + function SymbolicIndexingInterface.parameter_symbols(g::PartitionedGraphSystem) collect(Iterators.flatten((keys(g.param_namemap), keys(g.connection_namemap)))) end @@ -253,9 +272,9 @@ function set_params!!(buffer::GraphSystemParameters, param_map) (; param_namemap, connection_namemap) = buffer for (key, val) ∈ param_map if haskey(param_namemap, key) - buffer = set_param!!(buffer, key, param_namemap[key], val) + buffer = set_param!!(buffer, param_namemap[key], val) elseif haskey(connection_namemap, key) - buffer = set_param!!(buffer, key, connection_namemap[key], val) + buffer = set_param!!(buffer, connection_namemap[key], val) else error("Key $key does not correspond to a known parameter. ") end @@ -263,7 +282,10 @@ function set_params!!(buffer::GraphSystemParameters, param_map) buffer end -function set_param!!(buffer::GraphSystemParameters, key, (; tup_index, v_index, prop)::ParamIndex, val) + +# This is a possibly-out-of-place variant of set_parameter! that is meant to be used by `remake` where +# types are allowed to be widened. +function set_param!!(buffer::GraphSystemParameters, (; tup_index, v_index, prop)::ParamIndex, val) (; params_partitioned) = buffer params = params_partitioned[tup_index][v_index] params_new = set_param_prop(params, prop, val; allow_typechange=true) @@ -288,9 +310,9 @@ function re_eltype_params(params_partitioned) end end -function set_param!!(buffer::GraphSystemParameters, key, conn_index::ConnectionIndex, value) +function set_param!!(buffer::GraphSystemParameters, conn_index::ConnectionIndex, value) (;connection_matrices, connection_namemap) = buffer - (; nc, i_src, i_dst, j_src, j_dst, prop) = conn_index + (; nc, i_src, i_dst, j_src, j_dst, connection_key, prop) = conn_index conn_old = connection_matrices[nc][i_src, i_dst][j_src, j_dst] conn_new = setproperties(conn_old, NamedTuple{(prop,)}(value)) CR_new = typeof(conn_new) @@ -317,7 +339,7 @@ function set_param!!(buffer::GraphSystemParameters, key, conn_index::ConnectionI # Update the position in the namemap let conn_index_new = @set conn_index.nc = nc_new - connection_namemap[key] = conn_index_new # This is important so we don't lose track of where the parameter moved to! + connection_namemap[connection_key] = conn_index_new # This is important so we don't lose track of where the parameter moved to! end # Delete the old element! diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 2ede123..6c09bde 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -135,7 +135,7 @@ function particle_osc_prob(;x1, x2, m=3.0, mp1=1.0, kc_p1_p2=1, tspan = (0.0, 10 [:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0 ], tspan, - (:osc₊m => m, :particle1₊m =>mp1, :fac_coulomb_particle1_particle2 => kc_p1_p2) + (:osc₊m => m, :particle1₊m => mp1, :fac_coulomb_particle1_particle2 => kc_p1_p2) ) end function solve_particle_osc(;reltol=nothing, saveat=nothing, kwargs...) diff --git a/test/symbolic_indexing.jl b/test/symbolic_indexing.jl index adc9d34..05cda1e 100644 --- a/test/symbolic_indexing.jl +++ b/test/symbolic_indexing.jl @@ -1,10 +1,41 @@ include("particle_osc_example.jl") using SymbolicIndexingInterface -@testset "Symbolic Indexing of Vectors" begin +@testset "Symbolic Indexing of Vectors of observables" begin sol = solve_particle_osc(x1=1.0, x2=-1.0) a = getsym(sol, :particle1₊a)(sol)[end] ω = getsym(sol, :osc₊ω₀)(sol)[end] @test getsym(sol, [:osc₊ω₀, :particle1₊a])(sol)[end] == [ω, a] end + +@testset "setp and getp" begin + prob = particle_osc_prob(; x1 = 1.0, x2 = 0.0) + + # Test setp works + setp(prob, :particle1₊m)(prob, 2.0) + @test getp(prob, :particle1₊m)(prob) == 2 + + # Test type promotion and conversion + @test_broken begin + (prob, :particle1₊m)(prob, 20) + @test getp(prob, :particle1₊m)(prob) === 20.0 + end + # Test on connections as well + setp(prob, :fac_coulomb_particle1_particle2)(prob, 100) + @test getp(prob, :fac_coulomb_particle1_particle2)(prob) == 100 + + # Test type promotion and conversion on connections + @test_broken begin + setp(prob, :fac_coulomb_particle1_particle2)(prob, 100.0) + getp(prob, :fac_coulomb_particle1_particle2)(prob) == 100 + end + + # Error on type-unstable change + @test_throws ErrorException setp(prob, :particle1₊m)(prob, ones(3)) + + # Remake + prob = remake(prob, p = [:particle1₊m => 2 + 3im, :particle2₊m => 3 + 2im]) + setp(prob, :particle1₊m)(prob, 3 + 3.0im) + @test getp(prob, :particle1₊m)(prob) == 3 + 3im +end