diff --git a/Changelog.md b/Changelog.md new file mode 100644 index 0000000..9481ef5 --- /dev/null +++ b/Changelog.md @@ -0,0 +1,63 @@ +## GraphDynamics v0.7.0 + +### New features + ++ A `PolyesterScheduler()` scheduling object has been added which allows for parallelizing solves using Polyester.jl. This option helps performance during ODE solves where GC time dominates if multithreaded. ++ If any subsystem in a problem has a parameter named `dtmax`, then the smallest `dtmax` parameter in the system is forwarded as a keyword argument to `ODEProblem` and `SDEProblem`s to limit the maximum stepsizes allowed by the solver. ++ If a connection type overrides `GraphDynamics.connection_needs_ctx` to give true, then connections of that type will be given a fourth `ctx` argument which gives it access to the full list of `states_partitioned`, `params_partitioned`, and `connection_matrices` when accumulating inputs. ++ GraphDynamics integrates with Latexify.jl to latexify the equations of a GraphSystem. + +### Breaking changes + ++ `PartitionedGraphSystem` has been removed; `GraphSystem`s holds a field `g.flat_graph` field with a `PartitioningGraphSystem` object which actively flattens and partitions the graph during solving ++ `apply_discrete_event!`, `apply_continuous_event!`, and `ForeachConnectedSubsystem` have had their `vstates` and `vparams` arguments combined into a `sys_view` argument, which gives a view into the affected system for (and the connection form gets a `sys_view_src` and `sys_view_dst`). This `sys_view` can have it's fields be updated in place like so: +```julia +function GraphDynamics.apply_discrete_event!(integrator, sys_view, sys::Subsystem{MyType}, _) + sys_view.x[] = sys.y +end +``` +This will modify the `x` state or parameter of the system when the event triggers. + ++ `apply_subsystem_noise!`'s first argument is now modified in the same way as the view arguments of `apply_discrete_event!`. One would now write e.g. +```julia +function GraphDynamics.apply_subsystem_noise!(vstate, sys::Subsystem{BrownianParticle}, t) + # No noise in position, so we don't modify vstate[:x] + vstate.v[] = sys.σ # White noise in velocity with amplitude σ +end +``` +rather than `vstate[:v] = sys.σ` + +## GraphDynamics v0.6.0 + +[Diff since v0.5.0](https://github.com/Neuroblox/GraphDynamics.jl/compare/v0.5.0...v0.6.0) + +### Breaking changes + +Switched `computed_properties` and `computed_properties_with_inputs` to take a type tag instead of subsystem argument. Now, instead of adding methods like +```julia +function GraphDynamics.computed_properties_with_inputs(::Subsystem{Particle}) + a(sys, input) = input.F / sys.m + (; a) +end +``` +users should do +```julia +function GraphDynamics.computed_properties_with_inputs(::Type{Particle}) + a(sys, input) = input.F / sys.m + (; a) +end +``` + +## GraphDynamics v0.5.0 + +[Diff since v0.4.9](https://github.com/Neuroblox/GraphDynamics.jl/compare/v0.4.9...v0.5.0) + +### Breaking changes + +- Removed the fallback `(cr::ConnectionRule)(src, dst, t) = cr(src, dst)` which was added previously to avoid breakage when we made connection rules take a time argument in addition to the `src` and `dst` subsystems. The presence of this method was a bit of a annoying crutch +- Changed the arguments of `discrete_event_condition` from just `(conn, t)` to `(conn, t, sys_src, sys_dst)` so that events can trigger based off information in the source or destination subsystems as well +- Changed the arguments of `has_discrete_events` from just `(typeof(conn),)` for connection events to `(typeof(conn), get_tag(src), get_tag(dst))`, i.e. it now also can depend on the types of the source and dest subsystems +- Changed the arguments of `event_times` from just `(conn,)` to `(conn, sys_src, sys_dst)` for connection events. + +**Merged pull requests:** +- Remove `(::ConnectionRule)(src, dst, t)` fallback method; make `discrete_event_condition` on connections take `src` / `dst` args. (#44) (@MasonProtter) diff --git a/Project.toml b/Project.toml index c0863e1..07d0d4d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" -version = "0.6.0" +version = "0.7.0" [workspace] projects = ["test", "scrap"] @@ -9,32 +9,36 @@ projects = ["test", "scrap"] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +FieldViews = "ff5a1669-b1f2-423e-bbd7-b7fa0f7e0224" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" [weakdeps] -ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] -MTKExt = ["Symbolics", "ModelingToolkit"] +LatexifyExt = ["Symbolics", "Latexify"] +SymbolicsExt = ["Symbolics"] [compat] Accessors = "0.1" ConstructionBase = "1.5" DiffEqBase = "6" -ModelingToolkit = "9, 10" +FieldViews = "0.3.1" +Latexify = "0.16.10" OhMyThreads = "0.6, 0.7, 0.8" OrderedCollections = "1.6.3" RecursiveArrayTools = "3" SciMLBase = "2" SparseArrays = "1" SymbolicIndexingInterface = "0.3" -Symbolics = "6" +Symbolics = "6, 7" julia = "1.10" [extras] diff --git a/ext/LatexifyExt.jl b/ext/LatexifyExt.jl new file mode 100644 index 0000000..cec0e5b --- /dev/null +++ b/ext/LatexifyExt.jl @@ -0,0 +1,152 @@ +module LatexifyExt + +using Symbolics + +using Latexify + +using GraphDynamics: + GraphDynamics, + get_tag, + get_name, + initialize_input, + subsystem_differential, + get_states, + get_params, + to_subsystem, + Subsystem, + SubsystemStates, + SubsystemParams, + ConnectionRule, + connection_property_namemap, + GraphSystem, + nodes, + connections, + graph_equations, + node_equations, + connection_equations, + ConnectionIndex, + GraphSystemConnection, + system_wiring_rule!, + PartitioningGraphSystem + +@variables t +const _D = Differential(t) +const NAMESPACE_SEPARATOR_SYMBOL = Symbol(Symbolics.NAMESPACE_SEPARATOR) + +function GraphDynamics.node_equations(sys::T) where T + sym_subsys, state_vars, input_vars = to_symbolic_subsystem(sys) + rhss = subsystem_differential(sym_subsys, input_vars, t) + eqs = [_D(state_vars[i]) ~ rhss[i] for i in 1:length(rhss)] +end + +function GraphDynamics.node_equations(list_of_sys::Union{Tuple, Vector, Set, Base.KeySet}) + equations = map(collect(list_of_sys)) do n + node_equations(n) + end + reduce(vcat, equations) +end + +function GraphDynamics.connection_equations(sys::GraphSystem, src, dst) + subgraph = GraphSystem() + for (; data) in connections(sys, src, dst) + add_connection!(subgraph, src, dst; data...) + end + subnodes = [n for n ∈ nodes(subgraph) if n ∉ (src, dst)] + [connection_equations(subgraph); node_equations(subnodes)] +end + +function GraphDynamics.connection_equations(sys::PartitioningGraphSystem) + system_cache = Dict{Any, Tuple{Subsystem, NamedTuple, NamedTuple}}() + eqs = map(connections(sys)) do (; src, dst, conn) + GraphDynamics.connection_equations(conn, src, dst; system_cache) + end + reduce(vcat, eqs) +end + +function GraphDynamics.connection_equations(sys::GraphSystem) + GraphDynamics.connection_equations(sys.flat_graph) +end + +function GraphDynamics.connection_equations(conn::ConnectionRule, src, dst; system_cache=nothing) + if isnothing(system_cache) + sym_srcsys, _, _ = to_symbolic_subsystem(src) + sym_dstsys, inputs, _ = to_symbolic_subsystem(dst) + else + sym_srcsys, _, _ = get!(system_cache, src) do + to_symbolic_subsystem(src) + end + sym_dstsys, inputs, _ = get!(system_cache, dst) do + to_symbolic_subsystem(dst) + end + end + conn_props = connection_property_namemap(conn, get_name(src), get_name(dst)) + syms = map(collect(pairs(conn_props))) do (k, v) + if getfield(conn, k) isa Function + only(@variables $v(..)) + else + only(@variables $v) + end + end + + cons = typeof(conn).name.wrapper + sym_conn = cons(syms...) + rhss = sym_conn(sym_srcsys, sym_dstsys, t) + lhss = gen_variables(Val(keys(rhss)), Val(get_name(dst)), Val(true)) + + eqs = map(zip(lhss, rhss)) do (lhs, rhs) + lhs ~ rhs + end +end + +function namespaced_vars(namespace, syms; of_t = true) + ns_syms = map(syms) do name + Symbol(namespace, NAMESPACE_SEPARATOR_SYMBOL, name) + end + + if of_t + Tuple([Expr(:call, sym, :t) for sym in ns_syms]) + else + Tuple(ns_syms) + end +end + +function to_symbolic_subsystem(sys) + to_symbolic_subsystem(to_subsystem(sys); namespace = get_name(sys)) +end + +@generated function gen_variables(::Val{syms}, ::Val{namespace}, ::Val{of_t}) where {syms, namespace, of_t} + nsyms = namespaced_vars(namespace, syms; of_t) + quote + vars = @variables $(Expr(:tuple, nsyms...)) + NamedTuple{$(syms)}(vars) + end +end + +function to_symbolic_subsystem(sys::Subsystem{T}; namespace = :sys) where T + states = propertynames(get_states(sys)) + params = propertynames(get_params(sys)) + inputs = propertynames(initialize_input(sys)) + + ns = Val(namespace) + state_vars = gen_variables(Val(states), ns, Val(true)) + input_vars = gen_variables(Val(inputs), ns, Val(true)) + param_vars = gen_variables(Val(params), ns, Val(false)) + + sym_subsys_states = SubsystemStates{get_tag(sys)}(NamedTuple{states}(state_vars)) + sym_subsys_params = SubsystemParams{get_tag(sys)}(NamedTuple{params}(param_vars)) + sym_subsys = Subsystem{get_tag(sys)}(sym_subsys_states, sym_subsys_params) + + sym_subsys, state_vars, input_vars +end + +@latexrecipe function f(sys::Union{GraphSystem, PartitioningGraphSystem}; show_connection_equations = true) + sys = sys isa GraphSystem ? sys.flat_graph : sys + if show_connection_equations + return latexify([node_equations(collect(nodes(sys))); + connection_equations(sys)]) + else + return latexify(node_equations(collect(nodes(sys)))) + end +end + +end diff --git a/ext/MTKExt.jl b/ext/MTKExt.jl deleted file mode 100644 index 8f7dab1..0000000 --- a/ext/MTKExt.jl +++ /dev/null @@ -1,40 +0,0 @@ -module MTKExt - -using ModelingToolkit: ModelingToolkit, Num -using Symbolics: Symbolics, tosymbol -using GraphDynamics: GraphDynamics, GraphSystemParameters, PartitionedGraphSystem -using SymbolicIndexingInterface: SymbolicIndexingInterface - -function SymbolicIndexingInterface.is_variable(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.is_variable(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.variable_index(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.variable_index(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.is_parameter(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.is_parameter(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.parameter_index(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.parameter_index(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.is_independent_variable(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.is_independent_variable(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.is_observed(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.is_observed(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, var::Num) - SymbolicIndexingInterface.observed(sys, tosymbol(var; escape=false)) -end - -function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, vars::Union{Vector{Num}, Tuple{Vararg{Num}}}) - SymbolicIndexingInterface.observed(sys, tosymbol.(vars; escape=false)) -end - -end diff --git a/ext/SymbolicsExt.jl b/ext/SymbolicsExt.jl new file mode 100644 index 0000000..f0017f9 --- /dev/null +++ b/ext/SymbolicsExt.jl @@ -0,0 +1,39 @@ +module SymbolicsExt + +using Symbolics: Symbolics, tosymbol, Num +using GraphDynamics: GraphDynamics, GraphSystemParameters, GraphNamemap +using SymbolicIndexingInterface: SymbolicIndexingInterface + +function SymbolicIndexingInterface.is_variable(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.is_variable(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.variable_index(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.variable_index(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.is_parameter(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.is_parameter(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.parameter_index(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.parameter_index(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.is_independent_variable(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.is_independent_variable(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.is_observed(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.is_observed(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.observed(sys::GraphNamemap, var::Num) + SymbolicIndexingInterface.observed(sys, tosymbol(var; escape=false)) +end + +function SymbolicIndexingInterface.observed(sys::GraphNamemap, vars::Union{Vector{Num}, Tuple{Vararg{Num}}}) + SymbolicIndexingInterface.observed(sys, tosymbol.(vars; escape=false)) +end + +end diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index e26f07e..e8fa1f5 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -14,6 +14,7 @@ end subsystem_differential, apply_subsystem_noise!, subsystem_differential_requires_inputs, + connection_needs_ctx, initialize_input, combine, @@ -42,7 +43,8 @@ end get_name, connection_property_namemap, - make_connection_matrices + ArrayOfSubsystems, + ArrayOfSubsystemStates ) export @@ -50,7 +52,10 @@ export SubsystemParams, SubsystemStates, GraphSystem, - PartitionedGraphSystem, + GraphSystemParameters, + PartitioningGraphSystem, + GraphNamemap, + PartitionedIndex, get_tag, get_states, get_params, @@ -65,11 +70,14 @@ export add_node!, nodes, has_connection, - delete_connection! - + flatten_graph, + connection_equations, + is_flat, + PolyesterScheduler #---------------------------------------------------------- -using Base: @kwdef, @propagate_inbounds, isassigned + +using Base: @kwdef, @propagate_inbounds, isassigned, isstored using Base.Iterators: map as imap using Base.Cartesian: @nexprs @@ -98,13 +106,17 @@ using SymbolicIndexingInterface: setu, setp, getp, - observed + observed, + is_parameter, + parameter_index using Accessors: Accessors, @set, @reset, - @insert + @insert, + set, + PropertyLens using ConstructionBase: ConstructionBase, @@ -129,20 +141,38 @@ using DiffEqBase: DiffEqBase, anyeltypedual +using FieldViews: + FieldViews, + FieldViewable, + FieldView + +using Polyester: + Polyester, + @batch + + #---------------------------------------------------------- # Random utils include("utils.jl") +struct PolyesterScheduler end + #---------------------------------------------------------- # API functions to be implemented by new Systems struct SubsystemStates{T, Eltype, States <: NamedTuple} <: AbstractVector{Eltype} states::States end +function FieldViews.fieldmap(p::Type{SubsystemStates{T, Elt, NamedTuple{state_names, tup}}}) where {T, Elt, state_names, tup} + map(name -> :states => name, state_names) +end struct SubsystemParams{T, Params <: NamedTuple} params::Params end +function FieldViews.fieldmap(p::Type{SubsystemParams{T, NamedTuple{param_names, tup}}}) where {T, param_names, tup} + map(name -> :params => name, param_names) +end """ Subsystem{T, Eltype, StateNT, ParamNT} @@ -155,6 +185,11 @@ struct Subsystem{T, Eltype, States, Params} states::SubsystemStates{T, Eltype, States} params::SubsystemParams{T, Params} end +function FieldViews.fieldmap(::Type{Subsystem{T, Elt, <:NamedTuple{state_names}, <:NamedTuple{param_names}}}) where {T, Elt, state_names, param_names} + state_map = map(name -> :states => :states => name, state_names) + param_map = map(name -> :params => :params => name, param_names) + (state_map..., param_map...) +end """ get_tag(subsystem::Subsystem{T}) = T @@ -285,7 +320,7 @@ By default, it does nothing (no noise). Override this for stochastic subsystems. ```julia function GraphDynamics.apply_subsystem_noise!(vstate, sys::Subsystem{BrownianParticle}, t) # No noise in position, so we don't modify vstate[:x] - vstate[:v] = sys.σ # White noise in velocity with amplitude σ + vstate.v[] = sys.σ # White noise in velocity with amplitude σ end ``` """ @@ -294,12 +329,12 @@ function apply_subsystem_noise!(vstate, subsystem, t) end -# """ -# must_run_before(::Type{T}, ::Type{U}) +""" + must_run_before(::Type{T}, ::Type{U}) -# Overload this function to tell the ODE solver that subsystems of type `T` must run before subsystems of type `U`. Default `false`. -# """ -# must_run_before(::Type{T}, ::Type{U}) where {T, U} = false +Overload this function to tell the ODE solver that subsystems of type `T` must run before subsystems of type `U`. Default `false`. +""" +must_run_before(::Type{T}, ::Type{U}) where {T, U} = false function continuous_event_condition end function apply_continuous_event! end @@ -506,7 +541,6 @@ end ``` the default implementation would give ```julia - julia> GraphDynamics.connection_property_namemap(Coulomb(1.0), :p1, :p2) (:fac_Coulomb_p1_p2,) ``` @@ -521,14 +555,91 @@ function connection_property_namemap(conn::CR, name_src, name_dst) where CR NamedTuple{pnames}(vals) end +######## Equations +""" + node_equations(::Subsystem{T}) where T + +Output the differential equations for a node as LaTeX strings. Requires Latexify and Symbolics to be loaded. +""" +function node_equations end +""" + graph_equations(::PartitionedGraphSystem) + +Output the equations for a flattened graph system as LaTeX strings. Requires Latexify and Symbolics to be loaded. +""" +function graph_equations end +""" + connection_equations(conn::ConnectionRule, src::Subsystem{U}, dst::Subsystem{T}) + +Output the equations for the connection between node `src` and node `dst`. Requires Latexify and Symbolics to be loaded. +""" +function connection_equations end + +function graph_ode! end + + +""" + connection_needs_ctx(conn::ConnectionRule) :: Bool + +(default: `false`) determines if the call signature to `conn` should be of the form + + conn(src, dst, t, ctx::NamedTuple{states_partitioned, params_partitioned, connection_matrices}) + +or + + conn(src, dst, t) + +Overload this function to return `true` if you have a connection rule type that needs access to the wider-graph +structure. +""" +@inline connection_needs_ctx(x) = false + +struct StateIndex + idx::Int +end +struct ParamIndex + tup_index::Int + v_index::Int + prop::Symbol +end +struct CompuIndex + tup_index::Int + v_index::Int + prop::Symbol + requires_inputs::Bool +end +struct ConnectionIndex + nc::Int + i_src::Int + i_dst::Int + j_src::Int + j_dst::Int + connection_key::Union{Symbol, Nothing} + prop::Union{Symbol, Nothing} +end + +struct GraphNamemap + state_namemap::OrderedDict{Symbol, StateIndex} + param_namemap::OrderedDict{Symbol, ParamIndex} + compu_namemap::OrderedDict{Symbol, CompuIndex} + connection_namemap::OrderedDict{Symbol, ConnectionIndex} +end +function Base.copy(g::GraphNamemap) + GraphNamemap(copy.(( + g.state_namemap, + g.param_namemap, + g.compu_namemap, + g.connection_namemap + ))...) +end #---------------------------------------------------------- # Infrastructure for subsystems include("subsystems.jl") #---------------------------------------------------------- -# The GraphSystem type, and the stuff to turn it into a -# PartitionedGraphSystem +# The GraphSystem type +include("partitioning_graph_system.jl") include("graph_system.jl") #---------------------------------------------------------- diff --git a/src/graph_solve.jl b/src/graph_solve.jl index 81afd2a..3da5741 100644 --- a/src/graph_solve.jl +++ b/src/graph_solve.jl @@ -69,6 +69,33 @@ end end end +@generated function GraphDynamics._graph_ode!(dstates_partitioned::NTuple{Len, Any}, + states_partitioned ::NTuple{Len, Any}, + params_partitioned ::NTuple{Len, Any}, + connection_matrices::ConnectionMatrices{NConn}, + scheduler::PolyesterScheduler, + t,) where {Len, NConn} + quote + @nexprs $Len i -> begin + f = make_graph_ode_mapping_f( + Val(i), + dstates_partitioned, + states_partitioned, + params_partitioned, + connection_matrices, + SerialScheduler(), + t + ) + pforeach(f, eachindex(states_partitioned[i])) + end + end +end + +pforeach(f, itr) = @batch for j ∈ itr + f(j) +end + + @generated function _graph_ode!(dstates_partitioned::NTuple{Len, Any}#=mutated=#, states_partitioned ::NTuple{Len, Any}, @@ -116,52 +143,30 @@ function _graph_ode_mapping_f(j, ::Val{i}, connection_matrices::ConnectionMatrices{NConn}, scheduler, t) where {i, Len, NConn} - sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) - input = if subsystem_differential_requires_inputs(sys_dst) + sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + input = if subsystem_differential_requires_inputs(sys) calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices, t) else - initialize_input(sys_dst) + initialize_input(sys) end - apply_subsystem_differential!(@view(dstates_partitioned[i][j]), sys_dst, input, t) + apply_subsystem_differential!(@view(dstates_partitioned[i][j]), sys, input, t) end -""" - combine_inputs(subsys::Subsystem, - M::AbstractMatrix{<:ConnectionRule}, - j::Integer, - states_partitioned::AbstractVector{<:SubsystemStates}, - params_partitioned::abstractvector{<:SubsystemParams}, - scheduler; - init=initialize_input(subsys)) - -Given an input graph-subsystem `subsys`, a (sub-)connection matrix `M` whose `j`-th column describes the -connections between `subsys` and a (sub-)list of subsystems defined by `states_partitioned` and -`params_partitioned`, compute the total input that should be passed to `subsys` by combining all the input -signals sent from each connected subsystem. - -e.g. if the inputs are just numbers who are combined by adding them together, then this computes - -```math -\\sum_{l} M[l, j](Subsystem(states_partitioned[i], params_partitioned[l]), subsys) -``` -""" -function combine_inputs end - @generated function calculate_inputs(::Val{i}, j, states_partitioned::NTuple{Len, Any}, params_partitioned::NTuple{Len, Any}, connection_matrices::ConnectionMatrices{NConn}, - #TODO: remove the =nothing fallback - t=nothing) where {i, Len, NConn} + t) where {i, Len, NConn} quote - state = @inbounds states_partitioned[i][j] - subsys = @inbounds Subsystem(state, params_partitioned[i][j]) + subsys = @inbounds Subsystem(states_partitioned[i][j], params_partitioned[i][j]) input = initialize_input(subsys) + ctx = (; states_partitioned, params_partitioned, connection_matrices) @nexprs $Len k -> begin + subsystems_k = ArrayOfSubsystems(states_partitioned[k], params_partitioned[k]) @nexprs $NConn nc -> begin @inbounds begin M = connection_matrices[nc].data[k][i] # Same as cm[nc][k,i] but performs better when there's many types - input′ = combine_inputs(subsys, M, j, states_partitioned[k], params_partitioned[k], t, SerialScheduler();) + input′ = @inline combine_inputs(subsys, M, j, subsystems_k, t, ctx) input = combine(input, input′) end end @@ -170,31 +175,24 @@ function combine_inputs end end end -@noinline function combine_inputs(subsys, M, j, states_partitioned, params_partitioned, t, ::SerialScheduler; - init=initialize_input(subsys)) +function combine_inputs(subsys, M, j, subsystems_k, t, ctx; init=initialize_input(subsys)) acc = init - if M isa SparseMatrixCSC - @inbounds for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) - acc′ = Mlj(Subsystem(states_partitioned[l], params_partitioned[l]), subsys, t) # Now do the actual reducing step just like the above method - acc = combine(acc, acc′) - end - else - @inbounds @simd for l ∈ axes(M, 1) - acc′ = M[l,j](Subsystem(states_partitioned[l], params_partitioned[l]), subsys, t) # Now do the actual reducing step just like the above method - acc = combine(acc, acc′) - end + @inbounds for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) + t_ctx = connection_needs_ctx(Mlj) ? (t, ctx) : (t,) + acc′ = Mlj(subsystems_k[l], subsys, t_ctx...) + acc = combine(acc, acc′) end acc end -@inline combine_inputs(subsys, M::NotConnected, j, states_partitioned, params_partitioned, t, scheduler::SerialScheduler; - init=initialize_input(subsys)) = init +combine_inputs(subsys, M::NotConnected, j, subsystems_k, t, ctx; + init=initialize_input(subsys)) = init """ maybe_sparse_enumerate_col(M::AbstractMatrix, j) -Equivalent to `((l, M[l, j]) for l ∈ axes(M, 1))`, except if `M` isa `SparseMatrixCSC`, this will -only iterate over the non-zero values of `M`. +Equivalent to `((l, M[l, j]) for l ∈ axes(M, 1))`. If `M` isa `SparseMatrixCSC`, this will +only iterate over the non-zero values of `M`, ignoring structural zeros. """ function maybe_sparse_enumerate_col(M::SparseMatrixCSC, j) rows = rowvals(M) @@ -214,7 +212,6 @@ function maybe_sparse_enumerate_col(::NotConnected, j) () end - #---------------------------------------------------------- # Infra. for stochastic noise #---------------------------------------------------------- @@ -249,7 +246,8 @@ end for j ∈ js @inbounds begin sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) - apply_subsystem_noise!(@view(dstates_partitioned[i].data[:, j]), sys, t) + states_view = view(dstates_partitioned[i], j) + apply_subsystem_noise!(states_view, sys, t) idx += l end end @@ -307,18 +305,18 @@ function _continuous_affect!(integrator, @inbounds begin if has_continuous_events(eltype(states_partitioned[i])) N = length(states_partitioned[i]) + subsystems_i = ArrayOfSubsystems(states_partitioned[i], params_partitioned[i]) js = (1:N) .+ offset if idx ∈ js j = idx - offset - sview = @view states_partitioned[i][j] - pview = @view params_partitioned[i][j] sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + sys_view = @view subsystems_i[j] F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) if continuous_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices, t) - apply_continuous_event!(integrator, sview, pview, sys, F, input) + apply_continuous_event!(integrator, sys_view, sys, F, input) else - apply_continuous_event!(integrator, sview, pview, sys, F) + apply_continuous_event!(integrator, sys_view, sys, F) end end offset += N @@ -346,9 +344,10 @@ end trigger = false @nexprs $Len i -> begin if has_discrete_events(eltype(states_partitioned[i])) + subsystems_i = ArrayOfSubsystems(states_partitioned[i], params_partitioned[i]) for j ∈ eachindex(states_partitioned[i]) F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) - sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + sys = subsystems_i[j] cond = discrete_event_condition(sys, t, F) trigger |= cond discrete_event_cache[i][j] = cond @@ -360,14 +359,15 @@ end @nexprs $Len i -> begin @nexprs $Len k -> begin M = connection_matrices[nc].data[k][i] # Same as cm[nc][k,i] but performs better when there's many types - if !(M isa NotConnected) && has_discrete_events(eltype(M), - get_tag(eltype(states_partitioned[i])), - get_tag(eltype(states_partitioned[k]))) + get_tag(eltype(states_partitioned[k])), + get_tag(eltype(states_partitioned[i]))) + subsystems_k = ArrayOfSubsystems(states_partitioned[k], params_partitioned[k]) + subsystems_i = ArrayOfSubsystems(states_partitioned[i], params_partitioned[i]) for j ∈ eachindex(states_partitioned[i]) - sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + sys_dst = subsystems_i[j] for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) - sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l]) + sys_src = subsystems_k[l] discrete_event_condition(Mlj, t, sys_src, sys_dst) && return true end end @@ -401,17 +401,17 @@ end @nexprs $Len i -> begin # First we apply events to the states if has_discrete_events(eltype(states_partitioned[i])) - @inbounds for j ∈ eachindex(states_partitioned[i]) + subsystems_i = ArrayOfSubsystems(states_partitioned[i], params_partitioned[i]) + @inbounds for j ∈ eachindex(subsystems_i) if discrete_event_cache[i][j] sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) - sview = @view states_partitioned[i][j] - pview = @view params_partitioned[i][j] + sys_view = @view subsystems_i[j] F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) if discrete_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices, t) - apply_discrete_event!(integrator, sview, pview, sys, F, input) + apply_discrete_event!(integrator, sys_view, sys, F, input) else - apply_discrete_event!(integrator, sview, pview, sys, F) + apply_discrete_event!(integrator, sys_view, sys, F) end end end @@ -436,18 +436,21 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, connection_matrices::ConnectionMatrices{NConn}, integrator) where {i, k, nc, Len, NConn} function (j) - sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) - sview_dst = @view states_partitioned[i][j] - pview_dst = @view params_partitioned[i][j] + subsystems_i = ArrayOfSubsystems(states_partitioned[i], params_partitioned[i]) + subsystems_k = ArrayOfSubsystems(states_partitioned[k], params_partitioned[k]) + + sys_view_dst = @view subsystems_i[j] + sys_dst = sys_view_dst[] + M = connection_matrices.matrices[nc].data[k][i] if !(M isa NotConnected) && has_discrete_events(eltype(M), - get_tag(eltype(states_partitioned[i])), - get_tag(eltype(states_partitioned[k]))) + get_tag(eltype(states_partitioned[k])), + get_tag(eltype(states_partitioned[i]))) for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) - sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l]) + sys_view_src = @view subsystems_k[l] + sys_src = sys_view_src[] if discrete_event_condition(Mlj, t, sys_src, sys_dst) - sview_src = @view states_partitioned[k][l] - pview_src = @view params_partitioned[k][l] + if discrete_events_require_inputs(typeof(Mlj)) input_dst = calculate_inputs(Val(i), j, states_partitioned, @@ -460,15 +463,15 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, connection_matrices, t) apply_discrete_event!(integrator, - sview_src, pview_src, - sview_dst, pview_dst, + sys_view_src, + sys_view_dst, Mlj, sys_src, input_src, sys_dst, input_dst) else apply_discrete_event!(integrator, - sview_src, pview_src, - sview_dst, pview_dst, + sys_view_src, + sys_view_dst, Mlj, sys_src, sys_dst) end @@ -479,6 +482,8 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, end + + #----------------------------------------------------------------------- """ @@ -534,18 +539,18 @@ end (;l, states_partitioned, params_partitioned, connection_matrices) = FCS state = init @nexprs $Len i -> begin + subsystems_i = ArrayOfSubsystems(states_partitioned[i], params_partitioned[i]) @nexprs $NConn nc -> begin M = connection_matrices[nc].data[k][i] # Same as cm[nc][k,i] but performs better when there's many types if M isa NotConnected nothing else - for j ∈ eachindex(states_partitioned[i]) - if isassigned(M, l, j) + for j ∈ eachindex(subsystems_i) + if isstored(M, l, j) conn = M[l, j] - @inbounds states_view_dst = @view states_partitioned[i][j] - @inbounds params_view_dst = @view params_partitioned[i][j] - sys_dst = Subsystem(states_view_dst[], params_view_dst[]) - res = f(conn, sys_dst, states_view_dst, params_view_dst) + @inbounds sys_view_dst = @view subsystems_i[j] + sys_dst = sys_view_dst[] + res = f(conn, sys_dst, sys_view_dst) state = op(state, res) end end @@ -558,3 +563,17 @@ end (FCS::ForeachConnectedSubsystem)(f::F) where {F} = mapreduce(f, (_, _) -> nothing, FCS; init=nothing) +@generated function foreach_incoming_conn(f, cm::ConnectionMatrices{NConn, Tup}, ::Val{i}, j) where {NConn, NPar, i, Tup <: NTuple{NConn, ConnectionMatrix{NPar}}} + quote + @nexprs $NConn nc -> begin + @nexprs $NPar k -> begin + M = cm[nc].data[k][i] + if !(M isa NotConnected) + for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) + f(nc, k, l, Mlj) + end + end + end + end + end +end diff --git a/src/graph_system.jl b/src/graph_system.jl index ecb775a..2614405 100644 --- a/src/graph_system.jl +++ b/src/graph_system.jl @@ -3,10 +3,38 @@ struct GraphSystemConnection dst data::NamedTuple end + +function Base.show(io::IO, conn::GraphSystemConnection) + printstyled("GraphSystemConnection", bold = true) + println("\nsrc: $(get_name(conn.src))\ndst: $(get_name(conn.dst))\ndata: $(conn.data)") +end + struct GraphSystem + name::Union{Nothing, Symbol} data::OrderedDict{Any, OrderedDict{Any, Vector{GraphSystemConnection}}} + flat_graph::PartitioningGraphSystem +end + +is_flat(g::GraphSystem) = isnothing(g.flat_graph) + +function GraphSystem(name, data) + flat_graph = PartitioningGraphSystem(Symbol(name, :_flat)) + g = GraphSystem(name, data, flat_graph) + for n in nodes(g) + system_wiring_rule!(flat_graph, n) + end + for (;src, dst, data) in connections(g) + system_wiring_rule!(flat_graph, src, dst; data...) + end + g +end + +function Base.copy(g::GraphSystem) + GraphSystem(g.name, copy(g.data), copy(g.flat_graph)) end -GraphSystem() = GraphSystem(OrderedDict{Any, OrderedDict{Any, GraphSystemConnection}}()) + +GraphSystem(; name=nothing) = GraphSystem(name, OrderedDict{Any, OrderedDict{Any, GraphSystemConnection}}()) + GraphSystemConnection(src, dst; kwargs...) = GraphSystemConnection(src, dst, NamedTuple(kwargs)) function Base.show(io::IO, sys::GraphSystem) @@ -23,6 +51,7 @@ end nodes(g::GraphSystem) = keys(g.data) function add_node!(g::GraphSystem, blox) get!(g.data, blox) do + system_wiring_rule!(g.flat_graph, blox) OrderedDict{Any, GraphSystemConnection}() end end @@ -31,12 +60,19 @@ function connections(g::GraphSystem, src, dst) g.data[src][dst] end +function connections(g::GraphSystem, src) + Iterators.flatmap(g.data[src]) do (_, edges) + edges + end +end + function add_connection!(g::GraphSystem, src, dst; kwargs...) d_src = add_node!(g, src) d_dst = add_node!(g, dst) v = get!(d_src, dst, GraphSystemConnection[]) push!(v, GraphSystemConnection(src, dst, NamedTuple(kwargs))) + system_wiring_rule!(g.flat_graph, src, dst; kwargs...) end add_connection!(g::GraphSystem, src, dst, d::AbstractDict) = add_connection!(g, src, dst; d...) @@ -57,24 +93,16 @@ function Base.merge!(g1::GraphSystem, g2::GraphSystem) g1 end function Base.merge(g1::GraphSystem, g2::GraphSystem) - g3 = GraphSystem() + g3 = GraphSystem(;name=g1.name) merge!(g3, g1) merge!(g3, g2) g3 end -function delete_connection!(g::GraphSystem, conn::GraphSystemConnection) - v = g.data[conn.src][conn.dst] - i = findfirst(==(conn), v) - if isnothing(i) - @warn "Attempted to remove a connection that doesn't exist" - end - deleteat!(v, i) -end - function system_wiring_rule!(g, node) add_node!(g, node) end + function system_wiring_rule!(g, src, dst; kwargs...) if !haskey(kwargs, :conn) error("conn keyword argument not specified for connection between $src and $dst") @@ -82,224 +110,5 @@ function system_wiring_rule!(g, src, dst; kwargs...) add_connection!(g, src, dst; conn=kwargs[:conn], kwargs...) end -@kwdef struct PartitionedGraphSystem{CM <: ConnectionMatrices, S, P, SP, EVT, Ns, CONM, SNM, PNM, CNM, EP} - graph::Union{Nothing, GraphSystem} = nothing - flat_graph::Union{Nothing, GraphSystem} = nothing - connection_matrices::CM - states_partitioned::S - params_partitioned::P - subsystems_partitioned::SP = map(i -> map(j -> Subsystem(states_partitioned[i][j], params_partitioned[i][j]), - eachindex(states_partitioned[i], params_partitioned[i])), - eachindex(states_partitioned, params_partitioned)) - tstops::EVT = Float64[] - names_partitioned::Ns - connection_namemap::CONM = make_connection_namemape(names_partitioned, connection_matrices) - state_namemap::SNM = make_state_namemap(names_partitioned, states_partitioned) - param_namemap::PNM = make_param_namemap(names_partitioned, params_partitioned) - compu_namemap::CNM = make_compu_namemap(names_partitioned, states_partitioned, params_partitioned) - is_stochastic::Bool=any(v -> any(isstochastic, v), states_partitioned) - extra_params::EP = (;) -end - -function PartitionedGraphSystem(g::GraphSystem) - g_flat = GraphSystem() - for sys ∈ nodes(g) - system_wiring_rule!(g_flat, sys) - end - for (;src, dst, data) ∈ connections(g) - system_wiring_rule!(g_flat, src, dst; data...) - end - #================================================================================================== - Create a list of lists of the lowest level nodes in the flattened graph, partitioned by their type - so different types can be handled efficiently - - e.g. if we have - @named n1 = SysType1(x=1, y=2) - @named n2 = SysType1(x=1, y=3) - @named n3 = SysType2(a=1, b=2, c=3) - - in the graph, then we'd end up with - - nodes_paritioned = [SysType1[n1, n2], SysType1[n3]] - - ===================================================================================================# - - node_types = (unique ∘ imap)(typeof, nodes(g_flat)) - nodes_partitioned = map(node_types) do T - if isstochastic(T) - system_is_stochastic = true - end - filter(collect(nodes(g_flat))) do sys - sys isa T - end - end - tstops = Float64[] - subsystems_partitioned = (Tuple ∘ map)(nodes_partitioned) do v - map(v) do node - sys = to_subsystem(node) - for t ∈ event_times(sys) - push!(tstops, t) - end - sys - end - end - states_partitioned = (Tuple ∘ map)(v -> map(get_states, v), subsystems_partitioned) - params_partitioned = (Tuple ∘ map)(v -> map(get_params, v), subsystems_partitioned) - names_partitioned = (Tuple ∘ map)(v -> map(x -> convert(Symbol, get_name(x)), v), nodes_partitioned) - - #================================================================================================== - Create a ConnectionMatrices object containing structured information about how each lowest level nodes - is connected to other nodes, partitioned by the types of the nodes, and the types of the connections for - type stability. - e.g. if we have - - @named n1 = SysType1(x=1, y=2) - @named n2 = SysType1(x=1, y=3) - @named n3 = SysType2(a=1, b=2, c=3) - - add_connection!(g, n1, n2; conn=Conn1(1)) - add_connection!(g, n2, n3; conn=Conn1(2)) - add_connection!(g, n3, n1; conn=Conn2(3)) - add_connection!(g, n3, n2; conn=Conn2(4)) - - we'd get - connection_matrix_1 = Conn1[⎡. 1⎤⎡.⎤ - ⎣. .⎦⎣2⎦ - [. .][.]] - - connection_matrix_2 = Conn2[⎡. .⎤⎡.⎤ - ⎣. .⎦⎣.⎦ - [3 4][.]] - - ConnectionMatrices((connection_matrix_1, connection_matrix_2)) - - where the sub-matrices are sparse arrays. - - This allows for type-stable calculations involving the subsystems and their connections - ===================================================================================================# - (;connection_matrices, connection_tstops, connection_namemap) = make_connection_matrices(g_flat, nodes_partitioned; - subsystems_partitioned, names_partitioned) - - append!(tstops, connection_tstops) - - - PartitionedGraphSystem( - ;graph=g, - flat_graph=g_flat, - is_stochastic = any(isstochastic, node_types), - connection_matrices, - subsystems_partitioned, - states_partitioned, - params_partitioned, - tstops=unique!(tstops), - names_partitioned, - connection_namemap - ) -end - -function make_partitioned_nodes(g_flat) - node_types = (unique ∘ imap)(typeof, nodes(g_flat)) - nodes_partitioned = map(node_types) do T - filter(collect(nodes(g_flat))) do sys - sys isa T - end - end -end - - -function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nodes(g_flat); - pred=(_) -> true, - conn_key=:conn, - subsystems_partitioned=map(v -> map(to_subsystem, v), nodes_partitioned), - names_partitioned=map(v -> map(x -> convert(Symbol, get_name(x)), v), nodes_partitioned)) - check_no_double_connections(g_flat, conn_key) - connection_types = (imap)(connections(g_flat)) do (; src, dst, data) - if haskey(data, conn_key) && pred(data[conn_key]) - typeof(data[conn_key]) - else - nothing - end - end |> unique |> x -> filter(!isnothing, x) - connection_tstops = Float64[] - connection_namemap = OrderedDict{Symbol, ConnectionIndex}() - connection_matrices = (ConnectionMatrices ∘ Tuple ∘ map)(enumerate(connection_types)) do (nc, CT) - (ConnectionMatrix ∘ Tuple ∘ map)(enumerate(nodes_partitioned)) do (k, nodeks) - (Tuple ∘ map)(enumerate(nodes_partitioned)) do (i, nodeis) - ls = Int[] - js = Int[] - conns = CT[] - for (j, nodeij) ∈ enumerate(nodeis) - for (l, nodekl) ∈ enumerate(nodeks) - if has_connection(g_flat, nodekl, nodeij) - for (; data) = connections(g_flat, nodekl, nodeij) - if haskey(data, conn_key) - conn = data[conn_key] - if conn isa CT && pred(conn) - push!(js, j) - push!(ls, l) - push!(conns, conn) - - for (prop, name) ∈ pairs(connection_property_namemap(conn, names_partitioned[k][l], names_partitioned[i][j])) - connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) - end - - for t ∈ event_times(conn, subsystems_partitioned[k][l], subsystems_partitioned[i][j]) - push!(connection_tstops, t) - end - end - end - end - end - end - end - rule_matrix = if isempty(conns) - NotConnected{CT}() #{CT}(length(nodeks), length(nodeis)) - else - sparse(ls, js, conns, length(nodeks), length(nodeis)) - end - rule_matrix - end - end - end - (; connection_matrices, connection_tstops, connection_namemap) -end - -function check_no_double_connections(g, conn_key) - for src ∈ nodes(g) - for dst ∈ nodes(g) - if has_connection(g, src, dst) - ps = connections(g, src, dst) - conns = [data[conn_key] for (;data) ∈ connections(g, src, dst) if haskey(data, conn_key)] - if length(unique(typeof, conns)) < length(conns) - error("Cannot have multiple connections between the same two nodes of the same type. Got $(conns) between $src and $dst.") - end - end - end - end -end - -@generated function make_connection_namemape(names_partitioned::NTuple{Len, Any}, - connection_matrices::ConnectionMatrices{NConn}) where {Len, NConn} - quote - connection_namemap = OrderedDict{Symbol, ConnectionIndex}() - @nexprs $Len k -> begin - @nexprs $Len i -> begin - @nexprs $NConn nc -> begin - M = connection_matrices[nc].data[k][i] - if !(M isa NotConnected) - for j ∈ eachindex(names_partitioned) - for (l, conn) ∈ maybe_sparse_enumerate_col(M, j) - name_kl = names_partitioned[k][l] - name_ij = names_partitioned[i][j] - for (prop, name) ∈ pairs(connection_property_namemap(conn, name_kl, name_ij)) - connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) - end - end - end - end - end - end - end - connection_namemap - end -end +# Should not give different results on consecutive re-flattenings. +flatten_graph(g::GraphSystem; name=g.name) = g.flat_graph diff --git a/src/partitioning_graph_system.jl b/src/partitioning_graph_system.jl new file mode 100644 index 0000000..ba91a70 --- /dev/null +++ b/src/partitioning_graph_system.jl @@ -0,0 +1,262 @@ + +""" + PartitionedIndex{i}(j) + +Used for indexing into partitioned structures where the `i` index refers to the outer (typically static) +structure, and the inner index is dynamic. In GraphDynamics.jl, we often replace vectors of objects with +many possible types with tuples where each element of the tuple is a vector of a concrete type. This +allows for type-stable indexing and iteration. +""" +struct PartitionedIndex{i} + j::Int +end +Base.zero(::Type{PartitionedIndex{i}}) where {i} = PartitionedIndex{i}(0) +Base.zero(::PartitionedIndex{i}) where {i} = PartitionedIndex{i}(0) + +@propagate_inbounds Base.getindex(v::Union{AbstractVector, Tuple}, (;j)::PartitionedIndex{i}) where {i} = v[i][j] +@propagate_inbounds function Base.getindex(m::AbstractMatrix, + idx1::PartitionedIndex{k}, + idx2::PartitionedIndex{i}) where {i,k} + l = idx1.j + j = idx2.j + m[k,i][l,j] +end +function Base.getproperty(idx::PartitionedIndex{i}, s::Symbol) where {i} + if s == :i + convert(Int, i) + else + getfield(idx, s) + end +end +Base.propertynames(idx::PartitionedIndex) = (:i, :j) + +struct SparseMatrixBuilder{T} + data::OrderedDict{Tuple{Int, Int}, @NamedTuple{conn::T, kwargs::NamedTuple}} +end +Base.eltype(::SparseMatrixBuilder{T}) where {T} = T +Base.eltype(::Type{SparseMatrixBuilder{T}}) where {T} = T +SparseMatrixBuilder{T}() where {T} = SparseMatrixBuilder{T}(OrderedDict{Tuple{Int, Int}, @NamedTuple{conn::T, kwargs::NamedTuple}}()) +function Base.getindex(m::SparseMatrixBuilder, i::Integer, j::Integer) + m.data[(Int(i), Int(j))] +end + +function Base.zeros(::Type{SparseMatrixBuilder{T}}, sz::Integer...) where {T} + map(CartesianIndices(sz)) do _ + SparseMatrixBuilder{T}() + end +end +Base.zero(::Type{SparseMatrixBuilder{T}}) where {T} = SparseMatrixBuilder{T}() + +function SparseArrays.sparse(m::SparseMatrixBuilder, N::Integer, M::Integer) + Ls = (inds[1] for inds ∈ keys(m.data)) + Js = (inds[2] for inds ∈ keys(m.data)) + conns = [conn for (; conn) ∈ values(m.data)] + sparse(collect(Ls), collect(Js), conns, N, M) +end + +function Base.setindex!(M::AbstractMatrix{SparseMatrixBuilder{T}}, val, idx1::PartitionedIndex{k}, idx2::PartitionedIndex{i}) where {T, k, i} + l = idx1.j + j = idx2.j + M[k,i].data[(l,j)] = val +end +function extrude(M::AbstractMatrix{SparseMatrixBuilder{T}}) where {T} + n, m = size(M) + @assert n == m + [M zeros(SparseMatrixBuilder{T}, n, 1) + zeros(SparseMatrixBuilder{T}, 1, n) zero( SparseMatrixBuilder{T})] +end + +mutable struct PartitioningGraphSystem + const name::Union{Nothing, Symbol} + const node_namemap::OrderedDict{Symbol, PartitionedIndex} + + # One vector of nodes per node-type. + const nodes_partitioned::Vector{Vector} + const subsystems_partitioned::Vector{Vector} + + # A vector where each element corresponds to one connection type (BasicConnection, ReverseConnection, etc.) + # The matrices correspond to each node type, and then the SparseMatrixBuilders correspond to + # connections within one combination of types + const connections_partitioned::Vector{Matrix{SparseMatrixBuilder{T}} where {T}} + const tstops::Vector{Float64} + is_stochastic::Bool + const extra_params::OrderedDict{Symbol, Any} +end + +function Base.copy(g::PartitioningGraphSystem) + PartitioningGraphSystem( + g.name, + copy(g.node_namemap), + copy(g.nodes_partitioned), + copy(g.subsystems_partitioned), + copy(g.connections_partitioned), + copy(g.tstops), + g.is_stochastic, + copy(g.extra_params) + ) +end + +function PartitioningGraphSystem(name=nothing) + PartitioningGraphSystem(name, + OrderedDict{Symbol, PartitionedIndex}(), + Vector{<:Any}[], + Vector{<:Any}[], + Matrix{<:SparseMatrixBuilder}[], + Float64[], + false, + OrderedDict{Symbol, Any}()) +end + +function PartitionedIndex(g::PartitioningGraphSystem, node) + g.node_namemap[get_name(node)] +end + +function add_node!(g::PartitioningGraphSystem, x::T) where {T} + name = get_name(x) + if haskey(g.node_namemap, name) + node_old = g.nodes_partitioned[g.node_namemap[name]] + if !isequal(x, node_old) + error("Tried to add node with name $name to a PartitioningGraphSystem, but the PartitioningGraphSystem already had a node with that name which is not equal to the new value.\n New value: $x\n Old value: $node_old") + else + return x + end + end + sys = to_subsystem(x) + i = findfirst(v -> T <: eltype(v), g.nodes_partitioned) + if isnothing(i) + push!(g.nodes_partitioned, T[]) + push!(g.subsystems_partitioned, Subsystem{get_tag(sys)}[]) + for nc ∈ eachindex(g.connections_partitioned) + g.connections_partitioned[nc] = extrude(g.connections_partitioned[nc]) + end + i = length(g.nodes_partitioned) + end + push!(g.nodes_partitioned[i], x) + push!(g.subsystems_partitioned[i], sys) + foreach(t -> push!(g.tstops, t), event_times(sys)) + isstochastic(sys) && (g.is_stochastic = true) + g.node_namemap[name] = PartitionedIndex{i}(lastindex(g.nodes_partitioned[i])) + x +end + +add_connection!(g::PartitioningGraphSystem, src, dst; conn, kwargs...) = add_connection!(g, src, conn, dst; kwargs...) + +function add_connection!(g::PartitioningGraphSystem, src::T, conn::Conn, dst::U; kwargs...) where {T, Conn, U} + name_src = get_name(src) + name_dst = get_name(dst) + idx_src = get!(g.node_namemap, name_src) do + add_node!(g, src) + g.node_namemap[name_src] + end + idx_dst = get!(g.node_namemap, name_dst) do + add_node!(g, dst) + g.node_namemap[name_dst] + end + nc = findfirst(g.connections_partitioned) do mat + Conn <: eltype(eltype(mat)) + end + if isnothing(nc) + M = zeros(SparseMatrixBuilder{Conn}, length(g.nodes_partitioned), length(g.nodes_partitioned)) + push!(g.connections_partitioned, M) + nc = length(g.connections_partitioned) + end + for t ∈ event_times(conn, g.subsystems_partitioned[idx_src], g.subsystems_partitioned[idx_dst]) + push!(g.tstops, t) + end + builder = g.connections_partitioned[nc][idx_src.i, idx_dst.i] + if haskey(builder.data, (idx_src.j, idx_dst.j)) + error("Tried to add a connection of type $Conn between $name_src and $name_dst, but a connection of that type already exists.") + else + g.connections_partitioned[nc][idx_src, idx_dst] = (; conn, kwargs=NamedTuple(kwargs)) + end +end + +function nodes(g::PartitioningGraphSystem) + Iterators.flatten(g.nodes_partitioned) +end +function connections(g::PartitioningGraphSystem) + Iterators.flatmap(enumerate(g.connections_partitioned)) do (nc, mat) + Iterators.flatmap(CartesianIndices(mat)) do Idx + (k, i) = Tuple(Idx) + builder = mat[k, i] + Iterators.map(builder.data) do ((l,j), (; conn, kwargs)) + (; src=g.nodes_partitioned[k][l], dst=g.nodes_partitioned[i][j], conn, kwargs, nc, k, i, l, j) + end + end + end +end +function connections(g::PartitioningGraphSystem, src, dst) + name_src = get_name(src) + name_dst = get_name(dst) + (has_node(g, src) && has_node(g, dst)) || return () # empty iterator + idx_src = g.node_namemap[name_src] + idx_dst = g.node_namemap[name_dst] + + k = idx_src.i + l = idx_src.j + i = idx_dst.i + j = idx_dst.j + + src = g.nodes_partitioned[idx_src] + dst = g.nodes_partitioned[idx_dst] + itr = Iterators.filter(enumerate(g.partitioned_connections)) do (nc, mat) + haskey(mat[k,i].data, (l,j)) + end + Iterators.map(itr) do (nc, mat) + (; src, dst, mat[idx_src, idx_dst]..., nc, k, i, l, j) + end +end + +function has_node(g::PartitioningGraphSystem, x) + haskey(g.node_namemap, get_name(x)) +end +function has_connection(g::PartitioningGraphSystem, src, dst) + (has_node(g, src) && has_node(g, dst)) || return false + name_src = get_name(src) + name_dst = get_name(dst) + idx_src = g.node_namemap[name_src] + idx_dst = g.node_namemap[name_dst] + any(g.connections_partitioned) do mat + builder = mat[idx_src.i, idx_dst.i] + haskey(builder.data, (idx_src.j, idx_dst.j)) + end +end + +function Base.merge!(g1::PartitioningGraphSystem, g2::PartitioningGraphSystem) + for x ∈ nodes(g2) + #overloadable function that defaults to just adding the node to g1 + merge_node!(g1, g2, x) + end + for (;src, dst, conn, kwargs) ∈ connections(g2) + #overloadable function that defaults to just adding the connection to g1 + merge_connection!(g1, g2, src, conn, dst; kwargs...) + end + for (k, v) ∈ g2.extra_params + if !haskey(g1.extra_params, k) + g1.extra_params[k] = v + end + end + g1 +end +function Base.merge(g1::PartitioningGraphSystem, g2::PartitioningGraphSystem) + g3 = GraphSystem(;name=g1.name) + merge!(g3, g1) + merge!(g3, g2) + g3 +end + +""" + merge_node!(g1, g2, x) + +Default: `add_node!(g1, x)`. This function is called during `merge!(g1, g2)` to add the nodes from `g2` +into `g1`. Overload it for nodes `x` which may require custom handling. +""" +merge_node!(g1, g2, x) = add_node!(g1, x) + +""" + merge_connection!(g1, g2, src, conn, dst; kwargs...) + +Default: `add_connection!(g1, g2, src, conn, dst; kwargs...)`. This function is called during `merge!(g1, g2)` +to add the connections from `g2` into `g1`. Overload it for connections which may require custom handling. +""" +merge_connection!(g1, g2, src, conn, dst; kwargs...) = add_connection!(g1, src, conn, dst; kwargs...) diff --git a/src/problems.jl b/src/problems.jl index dc7449b..0091170 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -4,29 +4,19 @@ function (::Type{T})(g::GraphSystem, args...; kwargs...) where {T <: SciMLBase.A end function SciMLBase.ODEProblem(g::GraphSystem, u0map, tspan, param_map=[]; - scheduler=SerialScheduler(), tstops=Float64[], - allow_nonconcrete=false, global_events=(), kwargs...) - g_part = PartitionedGraphSystem(g) - ODEProblem(g_part, u0map, tspan, param_map; scheduler, tstops, allow_nonconcrete, global_events, kwargs...) -end -function SciMLBase.SDEProblem(g::GraphSystem, u0map, tspan, param_map=[]; - scheduler=SerialScheduler(), tstops=Float64[], - allow_nonconcrete=false, global_events=(), kwargs...) - g_part = PartitionedGraphSystem(g) - SDEProblem(g_part, u0map, tspan, param_map; scheduler, tstops, allow_nonconcrete, global_events, kwargs...) -end - - -function SciMLBase.ODEProblem(g::PartitionedGraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler(), tstops=Float64[], allow_nonconcrete=false, global_events=(), kwargs...) - nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map, global_events) - (; f, u, tspan, p, callback) = nt - if g.is_stochastic + p = GraphSystemParameters(g; scheduler, u0map, param_map) + (; symbolic_indexing_namemap, states_partitioned) = p + u0 = make_u0(p; allow_nonconcrete) + callback = make_callback(p; global_events) + f = ODEFunction{true, SciMLBase.FullSpecialize}(graph_ode!, sys=symbolic_indexing_namemap) + if g.flat_graph.is_stochastic error("Passed a stochastic GraphSystem to ODEProblem. You probably meant to use SDEProblem") end - tstops = vcat(tstops, nt.tstops) - prob = ODEProblem{true, SciMLBase.FullSpecialize}(f, u, tspan, p; callback, tstops, kwargs...) + tstops = vcat(tstops, g.flat_graph.tstops) + dtmax = make_dtmax(p) + prob = ODEProblem(f, u0, tspan, p; callback, tstops, dtmax, kwargs...) for (k, v) ∈ u0map setu(prob, k)(prob, v) end @@ -34,16 +24,21 @@ function SciMLBase.ODEProblem(g::PartitionedGraphSystem, u0map, tspan, param_ma prob end -function SciMLBase.SDEProblem(g::PartitionedGraphSystem, u0map, tspan, param_map=[]; +function SciMLBase.SDEProblem(g::GraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler(), tstops=Float64[], allow_nonconcrete=false, global_events=(), kwargs...) - nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map, global_events) - (; f, u, tspan, p, callback) = nt - if !g.is_stochastic + p = GraphSystemParameters(g; scheduler, u0map, param_map) + (; symbolic_indexing_namemap, states_partitioned) = p + u0 = make_u0(p; allow_nonconcrete) + callback = make_callback(p; global_events) + f = ODEFunction{true, SciMLBase.FullSpecialize}(graph_ode!, sys=symbolic_indexing_namemap) + if !g.flat_graph.is_stochastic error("Passed a non-stochastic GraphSystem to SDEProblem. You probably meant to use ODEProblem") end noise_rate_prototype = nothing # this'll need to change once we support correlated noise - prob = SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...) + dtmax = make_dtmax(p) + tstops = vcat(tstops, g.flat_graph.tstops) + prob = SDEProblem(f, graph_noise!, u0, tspan, p; callback, noise_rate_prototype, dtmax, kwargs...) for (k, v) ∈ u0map setu(prob, k)(prob, v) end @@ -51,32 +46,30 @@ function SciMLBase.SDEProblem(g::PartitionedGraphSystem, u0map, tspan, param_map prob end -Base.@kwdef struct GraphSystemParameters{PP, CM, S, PAP, DEC, NP, CONM, SNM, PNM, CNM, EP<:NamedTuple} +Base.@kwdef struct GraphSystemParameters{PP, SP, CM, S, PAP, DEC, NP, EP<:NamedTuple} + graph::Union{Nothing, GraphSystem} + states_partitioned::SP params_partitioned::PP connection_matrices::CM scheduler::S partition_plan::PAP discrete_event_cache::DEC names_partitioned::NP - connection_namemap::CONM - state_namemap::SNM - param_namemap::PNM - compu_namemap::CNM + symbolic_indexing_namemap::GraphNamemap extra_params::EP=(;) end function Base.copy(p::GraphSystemParameters) GraphSystemParameters( + copy(p.graph), + copy.(p.states_partitioned), copy.(p.params_partitioned), copy(p.connection_matrices), p.scheduler, (p.partition_plan), copy.(p.discrete_event_cache), copy.(p.names_partitioned), - copy(p.connection_namemap), - copy(p.state_namemap), - copy(p.param_namemap), - copy(p.compu_namemap), + copy(p.symbolic_indexing_namemap), map(copy, p.extra_params) ) end @@ -91,25 +84,40 @@ function DiffEqBase.anyeltypedual(p::ConnectionMatrix, ::Type{Val{counter}}) whe anyeltypedual(p.data) end -function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map, global_events) - (; states_partitioned, - params_partitioned, - connection_matrices, - tstops, - names_partitioned, - connection_namemap, - state_namemap, - param_namemap, - compu_namemap) = g +function GraphSystemParameters(g::GraphSystem; scheduler=SerialScheduler(), u0map=[], param_map=[]) + (; tstops) = g.flat_graph - params_partitioned = map(params_partitioned) do v - if !isconcretetype(eltype(v)) - unique_types = unique(typeof.(v)) + tupmap = Tuple ∘ map + subsystems_partitioned = Tuple(g.flat_graph.subsystems_partitioned) + states_partitioned = map(v -> get_states.(v), subsystems_partitioned) + names_partitioned = tupmap(g.flat_graph.nodes_partitioned) do v + get_name.(v) + end + connection_matrices = tupmap(enumerate(g.flat_graph.connections_partitioned)) do (nc, mat) + tupmap(axes(mat, 1)) do k + tupmap(axes(mat, 2)) do i + builder = mat[k,i] + if isempty(builder.data) + NotConnected{eltype(builder)}() + else + sparse(builder, + length(subsystems_partitioned[k]), + length(subsystems_partitioned[i])) + end + end + end |> ConnectionMatrix + end |> ConnectionMatrices + extra_params = NamedTuple(g.flat_graph.extra_params) + + params_partitioned = map(subsystems_partitioned) do v + pv = get_params.(v) + if !isconcretetype(eltype(pv)) + unique_types = unique(typeof.(pv)) @debug "Non-concrete param types. Promoting" unique_types - T = mapreduce(typeof, promote_type, v) - convert.(T, v) + T = mapreduce(typeof, promote_type, pv) + convert.(T, pv) else - v + pv end end @@ -136,7 +144,7 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete end re_eltype(s::SubsystemStates{T}) where {T} = convert(SubsystemStates{T, total_eltype}, s) - states_partitioned = map(states_partitioned) do v + states_partitioned = tupmap(states_partitioned) do v if eltype(eltype(v)) <: total_eltype v else @@ -150,32 +158,18 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete length(states_partitioned[i]) == length(params_partitioned[i]) || error("Incompatible state and parameter lengths") end + for nc ∈ 1:length(connection_matrices) for i ∈ eachindex(states_partitioned) for k ∈ eachindex(states_partitioned) M = connection_matrices[nc][i, k] if !(M isa NotConnected) size(M) == (length(states_partitioned[i]), length(states_partitioned[k])) || - error("Connection sub-matrix ($nc, $i, $k) has an incorrect size, expected $((length(states_partitioned[i]), length(states_partitioned[k]))), got $(size(connection_matrices[i, k])).") + error("Connection sub-matrix ($nc, $i, $k) has an incorrect size, expected $((length(states_partitioned[i]), length(states_partitioned[k]))), got $(size(connection_matrices[nc][i, k])).") end end end end - nce = sum(states_partitioned) do v - if has_continuous_events(eltype(v)) - length(v) - else - 0 - end - end - nde = sum(states_partitioned) do v - if has_discrete_events(eltype(v)) - length(v) - else - 0 - end - end - partition_plan = let offset=Ref(0) map(states_partitioned) do v sz = (length(eltype(v)), length(v)) @@ -186,30 +180,54 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete plan end end - u = reduce(vcat, map(v -> reduce(vcat, v), states_partitioned)) - if !allow_nonconcrete && !isconcretetype(eltype(u)) && !all(isconcretetype ∘ eltype, states_partitioned) - error(ArgumentError("The provided subsystem states do not have a concrete eltype. All partitions must contain the same eltype. Got `eltype(u) = $(eltype(u))`.")) - end - + symbolic_indexing_namemap = GraphNamemap( + names_partitioned, states_partitioned, params_partitioned, connection_matrices + ) discrete_event_cache = ntuple(length(states_partitioned)) do i len = has_discrete_events(eltype(states_partitioned[i])) ? length(states_partitioned[i]) : 0 falses(len) end + GraphSystemParameters(; + graph=g, + states_partitioned, + params_partitioned, + connection_matrices, + scheduler, + partition_plan, + discrete_event_cache, + names_partitioned, + symbolic_indexing_namemap, + extra_params) +end +function make_u0(g::GraphSystemParameters; allow_nonconcrete=false) + (; states_partitioned) = g + u0 = reduce(vcat, map(v -> reduce(vcat, v), states_partitioned)) + if !allow_nonconcrete && !isconcretetype(eltype(u0)) && !all(isconcretetype ∘ eltype, states_partitioned) + error(ArgumentError("The provided subsystem states do not have a concrete eltype. All partitions must contain the same eltype. Got `eltype(u) = $(eltype(u0))`.")) + end + u0 +end + +function make_callback(g::GraphSystemParameters; global_events=()) + (; states_partitioned) = g + nce = sum(states_partitioned) do v + has_continuous_events(eltype(v)) ? length(v) : 0 + end + nde = sum(states_partitioned) do v + has_discrete_events(eltype(v)) ? length(v) : 0 + end ce = nce > 0 ? VectorContinuousCallback(continuous_condition, continuous_affect!, nce) : nothing de = nde > 0 ? DiscreteCallback(discrete_condition, discrete_affect!) : nothing callback = CallbackSet(ce, de, global_events...) - f = GraphSystemFunction(graph_ode!, g) - p = GraphSystemParameters(; params_partitioned, - connection_matrices, - scheduler, - partition_plan, - discrete_event_cache, - names_partitioned, - connection_namemap, - state_namemap, - param_namemap, - compu_namemap) +end - (; f, u, tspan, p, callback, tstops) +function make_dtmax(p::GraphSystemParameters) + (; params_partitioned) = p + init = typemax(Float64) + minimum(params_partitioned; init) do v + minimum(v; init) do p + get(NamedTuple(p), :dtmax, init) + end + end end diff --git a/src/subsystems.jl b/src/subsystems.jl index 92e1b23..94f744c 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -23,7 +23,7 @@ function set_param_prop(s::SubsystemParams{T}, patch; allow_typechange=false) wh props = NamedTuple(s) props′ = merge(props, patch) if typeof(props) != typeof(props′) && !allow_typechange - param_setproperror(props, props′) + props′ = convert(typeof(props), props′) end SubsystemParams{T}(props′) end @@ -48,6 +48,12 @@ end :(promote_type($(param for param in Tup.parameters if param <: Number)...)) end +function Base.length( + ::Type{SubsystemParams{Name, NamedTuple{names, Tup}}} + ) where {Name, names, Tup} + length(names) +end + #------------------------------------------------------------ # Subsystem states function SubsystemStates{Name, Eltype, States}(v::AbstractVector) where {Name, Eltype, States <: NamedTuple} @@ -219,111 +225,124 @@ Base.eltype(::Type{<:Subsystem{<:Any, T}}) where {T} = T _deval(::Val{T}) where {T} = T function partitioned(v, partition_plan::NTuple{N, Any}) where {N} map(partition_plan) do (;inds, sz, TVal) + # to_structarray(_deval(TVal), v, inds, sz[2]) M = reshape(view(v, inds), sz...) - VectorOfSubsystemStates{_deval(TVal)}(M) + ArrayOfSubsystemStates{_deval(TVal)}(M) end end -struct VectorOfSubsystemStates{States, Mat <: AbstractMatrix} <: AbstractVector{States} - data::Mat +struct ArrayOfSubsystemStates{States, N, Store <: StridedArray} <: DenseArray{States, N} + parent::Store + function ArrayOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(v::StridedArray{U, M}) where {Name, T, U, M, snames, Tup} + @assert size(v,1) == length(snames) + V = promote_type(T,U) + States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}} + new{States, M-1, typeof(v)}(v) + end end -function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(v::AbstractMatrix{U}) where {Name, T, U, snames, Tup} - V = promote_type(T,U) - States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}} - VectorOfSubsystemStates{States, typeof(v)}(v) +const VectorOfSubsystemStates{States, Store} = ArrayOfSubsystemStates{States, 1, Store} +Base.size(v::ArrayOfSubsystemStates{States}) where {States} = size(parent(v))[2:end] +Base.parent(v::ArrayOfSubsystemStates) = getfield(v, :parent) +Base.IndexStyle(::Type{<:ArrayOfSubsystemStates}) = IndexCartesian() +Base.pointer(v::ArrayOfSubsystemStates) = pointer(parent(v)) +function Base.elsize(::Type{ArrayOfSubsystemStates{States, N, Store}}) where {States, N, Store} + sizeof(States) end -Base.size(v::VectorOfSubsystemStates{States}) where {States} = (size(v.data, 2),) - -@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, idx::Integer) where {States <: SubsystemStates} +@propagate_inbounds function Base.getindex(v::ArrayOfSubsystemStates{States}, idx::Integer...) where {States <: SubsystemStates} l = length(States) - @boundscheck checkbounds(v.data, 1:l, idx) - @inbounds States(view(v.data, 1:l, idx)) -end - -@noinline function sym_not_found_error(::Type{S}, s::Symbol) where {S<:SubsystemStates} - error("$S does not have a field $s") + data = parent(v) + @boundscheck checkbounds(data, 1:l, idx...) + @inbounds States(view(data, 1:l, idx...)) end - -@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, s::Symbol, idx::Integer) where {States <: SubsystemStates} - i = state_ind(States, s) - if isnothing(i) - sym_not_found_error(States, s) - end - v.data[i, idx] -end - -@propagate_inbounds function Base.setindex!(v::VectorOfSubsystemStates{States}, state::States, idx::Integer) where {States <: SubsystemStates} +@propagate_inbounds function Base.setindex!(v::ArrayOfSubsystemStates{States}, state::States′, idx::Integer...) where {States <: SubsystemStates, States′ <: SubsystemStates} l = length(States) - @boundscheck checkbounds(v.data, 1:l, idx) - @inbounds v.data[1:l, idx] .= Tuple(state) + data = parent(v) + @boundscheck checkbounds(data, 1:l, idx...) + @inbounds data[1:l, idx...] .= Tuple(convert(States, state)) v end -@propagate_inbounds function Base.setindex!(v::VectorOfSubsystemStates{States}, - val, - s::Symbol, - idx::Integer) where {States <: SubsystemStates} - i = state_ind(States, s) - if isnothing(i) - sym_not_found_error(States, s) - end - v.data[i, idx] = val -end +Base.IndexStyle(::Type{<:VectorOfSubsystemStates}) = IndexLinear() - -#------------------------------------------------------------------------- -struct SubsystemStatesView{States, Mat <: AbstractMatrix} <: AbstractArray{States, 0} - data::Mat - idx::Int -end -@propagate_inbounds function Base.view(v::VectorOfSubsystemStates{States, Mat}, idx::Int) where {States, Mat} +@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, idx::Integer) where {States <: SubsystemStates} l = length(States) - @boundscheck checkbounds(v.data, 1:l, idx) - SubsystemStatesView{States, Mat}(v.data, idx) + data = parent(v) + # @boundscheck checkbounds(data, :, idx) + # @inbounds + States(view(data, 1:l, idx)) end -Base.size(::SubsystemStatesView) = () -function Base.getindex(v::SubsystemStatesView{States}) where {States <: SubsystemStates} +@propagate_inbounds function Base.setindex!(v::VectorOfSubsystemStates{States}, state::States, idx::Integer) where {States <: SubsystemStates} l = length(States) - @inbounds States(view(v.data, 1:l, v.idx)) -end -@propagate_inbounds function Base.getindex(v::SubsystemStatesView{States}, s::Symbol) where {States <: SubsystemStates} - i = state_ind(States, s) - idx = v.idx - if isnothing(i) - sym_not_found_error(States, s) - end - @boundscheck checkbounds(v.data, i, idx) - @inbounds v.data[i, idx] + data = parent(v) + # @boundscheck checkbounds(data, :, idx) + # @inbounds + data[1:l, idx] .= Tuple(state) + v end -@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, state::States) where {States <: SubsystemStates} +function Base.getproperty(v::ArrayOfSubsystemStates, prop::Symbol) + FieldView{prop}(v) +end +@propagate_inbounds function Base.view(v::ArrayOfSubsystemStates{States}, inds...) where {States} l = length(States) - idx = v.idx - @boundscheck checkbounds(v.data, 1:l, idx) - tup = Tuple(state) - @inbounds begin - @simd for i ∈ 1:l - v.data[i, idx] = tup[i] - end + ArrayOfSubsystemStates{States}(view(parent(v), :, inds...)) +end + +#------------------------------------------------------------------------- + +struct ArrayOfSubsystems{T, N, Subsys<:Subsystem{T}, StateStore <:AbstractArray{<:SubsystemStates, N}, ParamStore <: AbstractArray{<:SubsystemParams, N}} <: AbstractArray{Subsys, N} + states::StateStore + params::ParamStore + function ArrayOfSubsystems(vstates::AbstractArray{SubsystemStates{T, Elt, SNT}, N}, + vparams::AbstractArray{SubsystemParams{T, PNT}, N} + ) where {T, Elt, N, SNT, PNT} + @assert size(vstates) == size(vparams) + new{T, N, Subsystem{T, Elt, SNT, PNT}, typeof(vstates), typeof(vparams)}(vstates, vparams) end - v end +const VectorOfSubsystems{States, Store} = ArrayOfSubsystems{States, 1, Store} +Base.size(v::ArrayOfSubsystems) = size(getfield(v, :states)) +Base.IndexStyle(::Type{<:ArrayOfSubsystems}) = IndexLinear() +get_states(x::ArrayOfSubsystems) = getfield(x, :states) +get_params(x::ArrayOfSubsystems) = getfield(x, :params) + -function Base.setindex!(v::SubsystemStatesView{States1}, state::States2) where {States1 <: SubsystemStates, States2 <: SubsystemStates} - state′ = convert(States1, state) - setindex!(v, state′) +@propagate_inbounds function Base.getindex(v::ArrayOfSubsystems, idx::Integer) + vstates = getfield(v, :states) + vparams = getfield(v, :params) + @boundscheck checkbounds(vstates, idx) + states = @inbounds vstates[idx] + params = @inbounds vparams[idx] + Subsystem(states, params) end +@propagate_inbounds function Base.setindex!(v::ArrayOfSubsystems{T, N, Subsys}, sys::Subsystem, idx::Integer) where {T, N, Subsys} + vstates = getfield(v, :states) + vparams = getfield(v, :params) + @boundscheck checkbounds(vstates, idx) + states = @inbounds vstates[idx] = get_states(sys) + params = @inbounds vparams[idx] = get_params(sys) + v +end -@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, val, s::Symbol) where {States <: SubsystemStates} - i = state_ind(States, s) - idx = v.idx - if isnothing(i) - sym_not_found_error(States, s) +function Base.getproperty(v::ArrayOfSubsystems{T, N, Subsystem{T, Elt, SNT, PNT}}, prop::Symbol) where {T, N, Elt, SNT, PNT} + if hasfield(SNT, prop) + FieldView{prop}(getfield(v, :states)) + elseif hasfield(PNT, prop) + FieldView{prop}(getfield(v, :params)) + else + @noinline errf(T, prop) = error("Type $T has no property $prop") + errf(eltype(v), prop) end - @boundscheck checkbounds(v.data, i, idx) - @inbounds v.data[i, v.idx] = val - v end + +function Base.view(v::ArrayOfSubsystems, args...) + vstates = view(getfield(v, :states), args...) + vparams = view(getfield(v, :params), args...) + ArrayOfSubsystems(vstates, vparams) +end + +get_parent_index(x::SubArray{T, 0}) where {T} = only(x.indices) +get_parent_index(x::ArrayOfSubsystems{T, 0}) where {T} = get_parent_index(get_params(x)) diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index daab92e..2f55dfd 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -1,94 +1,80 @@ -struct GraphSystemFunction{F, PGS <: PartitionedGraphSystem} <: Function - f::F - sys::PGS -end -(f::GraphSystemFunction{F})(args...; kwargs...) where {F} = f.f(args...; kwargs...) - -struct StateIndex - idx::Int -end -struct ParamIndex - tup_index::Int - v_index::Int - prop::Symbol -end -struct CompuIndex - tup_index::Int - v_index::Int - prop::Symbol - requires_inputs::Bool -end -struct ConnectionIndex - nc::Int - i_src::Int - i_dst::Int - j_src::Int - j_dst::Int - connection_key::Symbol - prop::Symbol -end - - -function make_state_namemap(names_partitioned::NTuple{N, Vector{Symbol}}, - states_partitioned::NTuple{N, AbstractVector{<:SubsystemStates}}) where {N} - namemap = OrderedDict{Symbol, StateIndex}() - idx = 1 - for i ∈ eachindex(names_partitioned, states_partitioned) - for j ∈ eachindex(names_partitioned[i], states_partitioned[i]) - states = states_partitioned[i][j] - for (k, name) ∈ enumerate(propertynames(states)) - propname = Symbol(names_partitioned[i][j], "₊", name) - namemap[propname] = StateIndex(idx) - idx += 1 - end - end - end - namemap -end - -function make_param_namemap(names_partitioned::NTuple{N, Vector{Symbol}}, - params_partitioned::NTuple{N, AbstractVector{<:SubsystemParams}}) where {N} - namemap = OrderedDict{Symbol, ParamIndex}() - for i ∈ eachindex(names_partitioned, params_partitioned) - for j ∈ eachindex(names_partitioned[i], params_partitioned[i]) - params = params_partitioned[i][j] - for name ∈ propertynames(params) - propname = Symbol(names_partitioned[i][j], "₊", name) - #TODO: this'll require some generalization to support weight params - namemap[propname] = ParamIndex(i, j, name) +function GraphNamemap(names_partitioned::Tuple, + states_partitioned::Tuple, + params_partitioned::Tuple, + connection_matrices::ConnectionMatrices) + nm = GraphNamemap(OrderedDict{Symbol,StateIndex}(), + OrderedDict{Symbol,ParamIndex}(), + OrderedDict{Symbol,CompuIndex}(), + OrderedDict{Symbol, ConnectionIndex}()) + + populate_property_namemaps!(nm, names_partitioned, states_partitioned, params_partitioned) + populate_connection_namemap!(nm.connection_namemap, names_partitioned, connection_matrices) + nm +end + +@generated function populate_property_namemaps!( + namemaps, + names_partitioned::NTuple{N, Vector{Symbol}}, + states_partitioned::NTuple{N, AbstractVector{<:SubsystemStates}}, + params_partitioned::NTuple{N, AbstractVector{<:SubsystemParams}}) where {N} + quote + sidx = 1 + @nexprs $N i -> begin + for j ∈ eachindex(names_partitioned[i], states_partitioned[i], params_partitioned[i]) + states = states_partitioned[i][j] + params = params_partitioned[i][j] + for name ∈ propertynames(states) + propname = Symbol(names_partitioned[i][j], "₊", name) + namemaps.state_namemap[propname] = StateIndex(sidx) + sidx += 1 + end + for name ∈ propertynames(params) + propname = Symbol(names_partitioned[i][j], "₊", name) + namemaps.param_namemap[propname] = ParamIndex(i, j, name) + end + sys = Subsystem(states, params) + tag = get_tag(sys) + for name ∈ keys(computed_properties(tag)) + requires_inputs = false + propname = Symbol(names_partitioned[i][j], "₊", name) + namemaps.compu_namemap[propname] = CompuIndex(i, j, name, requires_inputs) + end + for name ∈ keys(computed_properties_with_inputs(tag)) + requires_inputs = true + propname = Symbol(names_partitioned[i][j], "₊", name) + namemaps.compu_namemap[propname] = CompuIndex(i, j, name, requires_inputs) + end end end end - namemap -end - - -function make_compu_namemap(names_partitioned::NTuple{N, Vector{Symbol}}, - states_partitioned::NTuple{N, AbstractVector{<:SubsystemStates}}, - params_partitioned::NTuple{N, AbstractVector{<:SubsystemParams}}) where {N} - namemap = OrderedDict{Symbol, CompuIndex}() - for i ∈ eachindex(names_partitioned, states_partitioned, params_partitioned) - for j ∈ eachindex(names_partitioned[i], states_partitioned[i], params_partitioned[i]) - states = states_partitioned[i][j] - params = params_partitioned[i][j] - sys = Subsystem(states, params) - tag = get_tag(sys) - for name ∈ keys(computed_properties(tag)) - requires_inputs = false - propname = Symbol(names_partitioned[i][j], "₊", name) - namemap[propname] = CompuIndex(i, j, name, requires_inputs) - end - for name ∈ keys(computed_properties_with_inputs(tag)) - requires_inputs = true - propname = Symbol(names_partitioned[i][j], "₊", name) - namemap[propname] = CompuIndex(i, j, name, requires_inputs) +end + +@generated function populate_connection_namemap!( + namemap, + names_partitioned::NTuple{Len, Any}, + connection_matrices::ConnectionMatrices{NConn}) where {Len, NConn} + quote + @nexprs $Len k -> begin + @nexprs $Len i -> begin + @nexprs $NConn nc -> begin + M = connection_matrices[nc].data[k][i] + if !(M isa NotConnected) + for j ∈ eachindex(names_partitioned[i]) + for (l, conn) ∈ maybe_sparse_enumerate_col(M, j) + name_kl = names_partitioned[k][l] + name_ij = names_partitioned[i][j] + for (prop, name) ∈ pairs(connection_property_namemap(conn, name_kl, name_ij)) + namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) + end + end + end + end + end end end end - namemap end - function Base.getindex(u::AbstractArray, idx::StateIndex) u[idx.idx] end @@ -109,10 +95,13 @@ end function Base.getindex(cm::ConnectionMatrices, (; nc, i_src, i_dst, j_src, j_dst, prop)::ConnectionIndex) conn = cm[nc].data[i_src][i_dst][j_src, j_dst] - getproperty(conn, prop) + if isnothing(prop) + conn + else + getproperty(conn, prop) + end end - function Base.setindex!(u::GraphSystemParameters, val, p::ParamIndex) setindex!(u.params_partitioned, val, p) end @@ -126,28 +115,31 @@ function Base.setindex!(u::GraphSystemParameters, val, p::ConnectionIndex) setindex!(u.connection_matrices, val, p) end function Base.setindex!(u::ConnectionMatrices, val, (; nc, i_src, i_dst, j_src, j_dst, prop)::ConnectionIndex) - params = u[tup_index][v_index] - @reset params[prop] = val - setindex!(u[tup_index], params, v_index) + M = u[nc].data[i_src][i_dst] + if isnothing(prop) + M[j_src, j_dst] = val + else + M[j_src, j_dst] = setproperties(M[j_src, j_dst], NamedTuple{(prop,)}(val)) + end end -function SymbolicIndexingInterface.is_variable(g::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.is_variable(g::GraphNamemap, sym) haskey(g.state_namemap, sym) end -function SymbolicIndexingInterface.variable_index(f::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.variable_index(f::GraphNamemap, sym) get(f.state_namemap, sym, nothing) end -function SymbolicIndexingInterface.variable_symbols(g::PartitionedGraphSystem) +function SymbolicIndexingInterface.variable_symbols(g::GraphNamemap) collect(keys(g.state_namemap)) end -function SymbolicIndexingInterface.is_parameter(g::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.is_parameter(g::GraphNamemap, sym) haskey(g.param_namemap, sym) || haskey(g.connection_namemap, sym) end -function SymbolicIndexingInterface.parameter_index(g::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.parameter_index(g::GraphNamemap, sym) if haskey(g.param_namemap, sym) g.param_namemap[sym] else @@ -171,48 +163,55 @@ function SymbolicIndexingInterface.set_parameter!(p::GraphSystemParameters, val, params = params_partitioned[tup_index][v_index] params_new = set_param_prop(params, prop, val; allow_typechange=false) params_partitioned[tup_index][v_index] = params_new - p + val end function SymbolicIndexingInterface.set_parameter!(buffer::GraphSystemParameters, value, conn_index::ConnectionIndex) - (;connection_matrices, connection_namemap) = buffer + (;connection_matrices) = buffer (; nc, i_src, i_dst, j_src, j_dst, connection_key, prop) = conn_index conn_old = connection_matrices[nc][i_src, i_dst][j_src, j_dst] conn_new = setproperties(conn_old, NamedTuple{(prop,)}(value)) connection_matrices[nc][i_src, i_dst][j_src, j_dst] = conn_new - buffer + value +end + +function SymbolicIndexingInterface.parameter_index(p::GraphSystemParameters, sym) + parameter_index(p.symbolic_indexing_namemap, sym) end +function SymbolicIndexingInterface.is_parameter(p::GraphSystemParameters, sym) + is_parameter(p.symbolic_indexing_namemap, sym) +end -function SymbolicIndexingInterface.parameter_symbols(g::PartitionedGraphSystem) +function SymbolicIndexingInterface.parameter_symbols(g::GraphNamemap) collect(Iterators.flatten((keys(g.param_namemap), keys(g.connection_namemap)))) end -function SymbolicIndexingInterface.is_independent_variable(sys::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.is_independent_variable(sys::GraphNamemap, sym) sym === :t end -function SymbolicIndexingInterface.independent_variable_symbols(sys::PartitionedGraphSystem) +function SymbolicIndexingInterface.independent_variable_symbols(sys::GraphNamemap) (:t,) end -function SymbolicIndexingInterface.is_time_dependent(sys::PartitionedGraphSystem) +function SymbolicIndexingInterface.is_time_dependent(sys::GraphNamemap) true end -function SymbolicIndexingInterface.observed(f::ODEFunction{a, b, F}, sym::Symbol) where {a, b, F<:GraphSystemFunction} - observed(f.f.sys, sym) +function SymbolicIndexingInterface.observed(f::ODEFunction{a, b, typeof(graph_ode!)}, sym::Symbol) where {a, b} + observed(f.sys, sym) end -function SymbolicIndexingInterface.observed(f::SDEFunction{a, b, F}, sym::Symbol) where {a, b, F<:GraphSystemFunction} - observed(f.f.sys, sym) +function SymbolicIndexingInterface.observed(f::SDEFunction{a, b, typeof(graph_ode!)}, sym::Symbol) where {a, b} + observed(f.sys, sym) end -function SymbolicIndexingInterface.is_observed(sys::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.is_observed(sys::GraphNamemap, sym) haskey(sys.compu_namemap, sym) end -function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, syms::Union{Vector{Symbol}, Tuple{Vararg{Symbol}}}) +function SymbolicIndexingInterface.observed(sys::GraphNamemap, syms::Union{Vector{Symbol}, Tuple{Vararg{Symbol}}}) function (u, p, t) map(syms) do sym observed(sys, sym)(u, p, t) @@ -220,7 +219,7 @@ function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, syms::U end end -function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, sym) +function SymbolicIndexingInterface.observed(sys::GraphNamemap, sym) (; tup_index, v_index, prop, requires_inputs) = sys.compu_namemap[sym] # lift these to the type domain so that we specialize on them in the returned closures @@ -249,28 +248,13 @@ function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, sym) end end -# function SymbolicIndexingInterface.all_solvable_symbols(sys::PartitionedGraphSystem) -# vcat( -# collect(keys(sys.state_namemap)), -# collect(keys(sys.observed_namemap)), -# ) -# end - -# function SymbolicIndexingInterface.all_symbols(sys::PartitionedGraphSystem) -# vcat( -# all_solvable_symbols(sys), -# collect(keys(sys.param_namemap)), -# :t -# ) -# end - function SymbolicIndexingInterface.remake_buffer(sys, oldbuffer::GraphSystemParameters, idxs, vals) newbuffer = copy(oldbuffer) set_params!!(newbuffer, zip(idxs, vals)) end function set_params!!(buffer::GraphSystemParameters, param_map) - (; param_namemap, connection_namemap) = buffer + (; param_namemap, connection_namemap) = buffer.symbolic_indexing_namemap for (key, val) ∈ param_map if haskey(param_namemap, key) buffer = set_param!!(buffer, param_namemap[key], val) @@ -312,10 +296,10 @@ function re_eltype_params(params_partitioned) end function set_param!!(buffer::GraphSystemParameters, conn_index::ConnectionIndex, value) - (;connection_matrices, connection_namemap) = buffer + (;connection_matrices, symbolic_indexing_namemap) = buffer (; nc, i_src, i_dst, j_src, j_dst, connection_key, prop) = conn_index conn_old = connection_matrices[nc][i_src, i_dst][j_src, j_dst] - conn_new = setproperties(conn_old, NamedTuple{(prop,)}(value)) + conn_new = setproperties(conn_old, NamedTuple{(prop,)}((value,))) CR_new = typeof(conn_new) CR_old = typeof(conn_old) if !(CR_new <: CR_old) @@ -340,7 +324,7 @@ function set_param!!(buffer::GraphSystemParameters, conn_index::ConnectionIndex, # Update the position in the namemap let conn_index_new = @set conn_index.nc = nc_new - connection_namemap[connection_key] = conn_index_new # This is important so we don't lose track of where the parameter moved to! + symbolic_indexing_namemap.connection_namemap[connection_key] = conn_index_new # This is important so we don't lose track of where the parameter moved to! end # Delete the old element! diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index dae73b0..5c5ce8a 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -128,10 +128,10 @@ function particle_osc_prob(;x1, x2, m=3.0, mp1=1.0, kc_p1_p2=1, tspan = (0.0, 10 add_connection!(g, particle1, particle2; conn=Coulomb(1)) add_connection!(g, particle2, particle1; conn=Coulomb(1)) - gsys = PartitionedGraphSystem(g) + # gsys._namemap[:w_particle1_particle2_spring] = prob = ODEProblem( - gsys, + g, [:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0 ], tspan, @@ -176,58 +176,60 @@ function test_derivative(f, v; rtol=1e-3) end function sensitivity_test() - @testset "Sensitivities" begin - @testset "via constructor" begin + # @testset "Sensitivities" + begin + # @testset "via constructor" + begin test_jacobian([1.0, -1.0]) do (x1, x2) sol = solve_particle_osc(;x1, x2, reltol=1e-8) - [sol[:particle1₊x, end], sol[:particle2₊x, end], sol[:osc₊x, end]] + [sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]] end test_derivative(3.0) do m sol = solve_particle_osc(;x1=one(typeof(m)), x2=-1.0, m=m, reltol=1e-8) - sol[:particle1₊x, end] + sol[:particle1₊x][end] end test_derivative(1.0) do fac sol = solve_particle_osc(;x1=1.0, x2=-1.0, kc_p1_p2=fac, reltol=1e-8) - [sol[:particle1₊x, end], sol[:particle2₊x, end], sol[:osc₊x, end]] + [sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]] end test_jacobian([1.0, -1.0, 3.0, 1.0, 2.0]) do (x1, x2, m, kc_p1_p2, mp1) sol = solve_particle_osc(;x1, x2, m, kc_p1_p2, mp1, reltol=1e-8) - [sol[:particle1₊x, end], sol[:particle2₊x, end], sol[:osc₊x, end]] + [sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]] end end - @testset "via remake" begin + # @testset "via remake" + begin prob = particle_osc_prob(; x1=1.0, x2=-1.0) test_jacobian([1.0, -1.0]) do (x1, x2) prob2 = remake(prob; u0=[:particle1₊x => x1, :particle2₊x => x2]) sol = solve(prob2, Tsit5(); reltol=1e-8) - [sol[:particle1₊x, end], sol[:particle2₊x, end], sol[:osc₊x, end]] + [sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]] end test_derivative(3.0) do m prob2 = remake(prob; p=[:osc₊m => m]) sol = solve(prob2, Tsit5(); reltol=1e-8) - sol[:particle1₊x, end] + sol[:particle1₊x][end] end test_derivative(1.0) do k prob2 = remake(prob; p=[:k_Spring_particle1_osc => k]) sol = solve(prob2, Tsit5(); reltol=1e-8) - sol[:particle1₊x, end] + sol[:particle1₊x][end] end test_jacobian([1.0, -1.0, 3.0, 1.0, 2.0]) do (x1, x2, m, kc_p1_p2, mp1) prob2 = remake(prob; u0=[:particle1₊x => x1, :particle2₊x => x2], p=[:osc₊m => m, :particle1₊m => mp1, :fac_coulomb_particle1_particle2 => kc_p1_p2]) sol = solve(prob2, Tsit5(); reltol=1e-8) - [sol[:particle1₊x, end], sol[:particle2₊x, end], sol[:osc₊x, end]] + [sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]] end - end end end diff --git a/test/runtests.jl b/test/runtests.jl index c568799..3c1f9b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using SafeTestsets -@safetestset "Particle/Oscillator example" begin +# @safetestset "Particle/Oscillator example" +begin include("particle_osc_example.jl") solution_solve_test() sensitivity_test() diff --git a/test/symbolic_indexing.jl b/test/symbolic_indexing.jl index 05cda1e..a3d2539 100644 --- a/test/symbolic_indexing.jl +++ b/test/symbolic_indexing.jl @@ -17,10 +17,9 @@ end @test getp(prob, :particle1₊m)(prob) == 2 # Test type promotion and conversion - @test_broken begin - (prob, :particle1₊m)(prob, 20) - @test getp(prob, :particle1₊m)(prob) === 20.0 - end + setp(prob, :particle1₊m)(prob, 20) + @test getp(prob, :particle1₊m)(prob) === 20.0 + # Test on connections as well setp(prob, :fac_coulomb_particle1_particle2)(prob, 100) @test getp(prob, :fac_coulomb_particle1_particle2)(prob) == 100 @@ -32,7 +31,7 @@ end end # Error on type-unstable change - @test_throws ErrorException setp(prob, :particle1₊m)(prob, ones(3)) + @test_throws Exception setp(prob, :particle1₊m)(prob, ones(3)) # Remake prob = remake(prob, p = [:particle1₊m => 2 + 3im, :particle2₊m => 3 + 2im])