diff --git a/Project.toml b/Project.toml index 1a385cf..39a770b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" -version = "0.4.9" +version = "0.5.0" [workspace] projects = ["test", "scrap"] diff --git a/README.md b/README.md index f270a2a..96454f5 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ struct Spring <: ConnectionRule k::Float64 end -function ((;k)::Spring)(src::Subsystem, dst::Subsystem) +function ((;k)::Spring)(src::Subsystem, dst::Subsystem, t) # Calculate the force on subsystem `dst` due to being connected with # subsystem `src` by a spring with spring constant `k`. F = k * (src.x - dst.x) @@ -119,7 +119,7 @@ end ``` julia struct Coulomb <: ConnectionRule end -function (::Coulomb)(src::Subsystem, dst::Subsystem) +function (::Coulomb)(src::Subsystem, dst::Subsystem, t) # Calculate the Coulomb force on subsystem `dst` due to the charge of subsystem `src` F = -src.q * dst.q * sign(src.x - dst.x)/(abs(src.x - dst.x))^2 # Return the `input` being sent to the `dst` subsystem diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index ff9739f..33422ed 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -69,7 +69,7 @@ export #---------------------------------------------------------- -using Base: @kwdef, @propagate_inbounds +using Base: @kwdef, @propagate_inbounds, isassigned using Base.Iterators: map as imap using Base.Cartesian: @nexprs @@ -232,11 +232,6 @@ end """ computed_properties(s::Subsystem) = (;) -# TODO: delete this in the next breaking release. -# Accidentally released this with computed_properies -# as the name to use -const computed_properies = computed_properties - """ computed_properties_with_inputs(s::Subsystem) @@ -382,6 +377,7 @@ for s ∈ [:continuous, :discrete] $has_events(::Type{<:Subsystem{T}}) where {T} = $has_events(T) $has_events(::Type{<:SubsystemStates{T}}) where {T} = $has_events(T) $has_events(::Type{T}) where {T} = false + $has_events(::Type, ::Type, ::Type) = false $events_require_inputs(::Subsystem{T}) where {T} = $events_require_inputs(T) $events_require_inputs(::Type{<:Subsystem{T}}) where {T} = $events_require_inputs(T) @@ -391,19 +387,28 @@ for s ∈ [:continuous, :discrete] end end + +""" + event_times(::Subsystem{SysType}) = () + +add methods to this function if a subsystem type `SysType` has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time. """ - event_times(::T) = () +event_times(::Subsystem) = () -add methods to this function if a subsystem or connection type has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time. """ -event_times(::Any) = () + event_times(::ConnType, ::Subsystem{SysSrc}, ::Subsystem{SysDst}) = () + +add methods to this function if a connection type `ConnType` has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time. +""" +event_times(::Any, ::Any, ::Any) = () abstract type ConnectionRule end -(c::ConnectionRule)(src, dst, t) = c(src, dst) Base.zero(::T) where {T <: ConnectionRule} = zero(T) struct NotConnected{CR} end Base.getindex(::NotConnected{CR}, inds...) where {CR} = zero(CR) +Base.eltype(::NotConnected{CR}) where {CR} = CR + Base.copy(c::NotConnected) = c struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected{CR}, AbstractMatrix{CR}}}}} data::Tup diff --git a/src/graph_solve.jl b/src/graph_solve.jl index 4f8729a..81afd2a 100644 --- a/src/graph_solve.jl +++ b/src/graph_solve.jl @@ -210,6 +210,9 @@ end function maybe_sparse_enumerate_col(M::AbstractMatrix, j) enumerate(@view(M[:, j])) end +function maybe_sparse_enumerate_col(::NotConnected, j) + () +end #---------------------------------------------------------- @@ -322,7 +325,7 @@ function _continuous_affect!(integrator, end end end -end +end #---------------------------------------------------------- # Infra. for discrete events. @@ -357,10 +360,15 @@ end @nexprs $Len i -> begin @nexprs $Len k -> begin M = connection_matrices[nc].data[k][i] # Same as cm[nc][k,i] but performs better when there's many types - if has_discrete_events(eltype(M)) + + if !(M isa NotConnected) && has_discrete_events(eltype(M), + get_tag(eltype(states_partitioned[i])), + get_tag(eltype(states_partitioned[k]))) for j ∈ eachindex(states_partitioned[i]) + sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) - discrete_event_condition(Mlj, t) && return true + sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l]) + discrete_event_condition(Mlj, t, sys_src, sys_dst) && return true end end end @@ -432,10 +440,12 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, sview_dst = @view states_partitioned[i][j] pview_dst = @view params_partitioned[i][j] M = connection_matrices.matrices[nc].data[k][i] - if has_discrete_events(eltype(M)) + if !(M isa NotConnected) && has_discrete_events(eltype(M), + get_tag(eltype(states_partitioned[i])), + get_tag(eltype(states_partitioned[k]))) for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) - if discrete_event_condition(Mlj, t) - sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l]) + sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l]) + if discrete_event_condition(Mlj, t, sys_src, sys_dst) sview_src = @view states_partitioned[k][l] pview_src = @view params_partitioned[k][l] if discrete_events_require_inputs(typeof(Mlj)) @@ -530,8 +540,8 @@ end nothing else for j ∈ eachindex(states_partitioned[i]) - @inbounds conn = M[l, j] - if !iszero(conn) + if isassigned(M, l, j) + conn = M[l, j] @inbounds states_view_dst = @view states_partitioned[i][j] @inbounds params_view_dst = @view params_partitioned[i][j] sys_dst = Subsystem(states_view_dst[], params_view_dst[]) diff --git a/src/graph_system.jl b/src/graph_system.jl index 2e70190..ecb775a 100644 --- a/src/graph_system.jl +++ b/src/graph_system.jl @@ -82,12 +82,15 @@ function system_wiring_rule!(g, src, dst; kwargs...) add_connection!(g, src, dst; conn=kwargs[:conn], kwargs...) end -@kwdef struct PartitionedGraphSystem{CM <: ConnectionMatrices, S, P, EVT, Ns, CONM, SNM, PNM, CNM, EP} +@kwdef struct PartitionedGraphSystem{CM <: ConnectionMatrices, S, P, SP, EVT, Ns, CONM, SNM, PNM, CNM, EP} graph::Union{Nothing, GraphSystem} = nothing flat_graph::Union{Nothing, GraphSystem} = nothing connection_matrices::CM states_partitioned::S params_partitioned::P + subsystems_partitioned::SP = map(i -> map(j -> Subsystem(states_partitioned[i][j], params_partitioned[i][j]), + eachindex(states_partitioned[i], params_partitioned[i])), + eachindex(states_partitioned, params_partitioned)) tstops::EVT = Float64[] names_partitioned::Ns connection_namemap::CONM = make_connection_namemape(names_partitioned, connection_matrices) @@ -174,7 +177,8 @@ function PartitionedGraphSystem(g::GraphSystem) This allows for type-stable calculations involving the subsystems and their connections ===================================================================================================# - (;connection_matrices, connection_tstops, connection_namemap) = make_connection_matrices(g_flat, nodes_partitioned) + (;connection_matrices, connection_tstops, connection_namemap) = make_connection_matrices(g_flat, nodes_partitioned; + subsystems_partitioned, names_partitioned) append!(tstops, connection_tstops) @@ -184,6 +188,7 @@ function PartitionedGraphSystem(g::GraphSystem) flat_graph=g_flat, is_stochastic = any(isstochastic, node_types), connection_matrices, + subsystems_partitioned, states_partitioned, params_partitioned, tstops=unique!(tstops), @@ -195,18 +200,18 @@ end function make_partitioned_nodes(g_flat) node_types = (unique ∘ imap)(typeof, nodes(g_flat)) nodes_partitioned = map(node_types) do T - if isstochastic(T) - system_is_stochastic = true - end filter(collect(nodes(g_flat))) do sys sys isa T end end end + function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nodes(g_flat); pred=(_) -> true, - conn_key=:conn) + conn_key=:conn, + subsystems_partitioned=map(v -> map(to_subsystem, v), nodes_partitioned), + names_partitioned=map(v -> map(x -> convert(Symbol, get_name(x)), v), nodes_partitioned)) check_no_double_connections(g_flat, conn_key) connection_types = (imap)(connections(g_flat)) do (; src, dst, data) if haskey(data, conn_key) && pred(data[conn_key]) @@ -234,11 +239,11 @@ function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nod push!(ls, l) push!(conns, conn) - for (prop, name) ∈ pairs(connection_property_namemap(conn, get_name(nodekl), get_name(nodeij))) + for (prop, name) ∈ pairs(connection_property_namemap(conn, names_partitioned[k][l], names_partitioned[i][j])) connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop) end - for t ∈ event_times(conn) + for t ∈ event_times(conn, subsystems_partitioned[k][l], subsystems_partitioned[i][j]) push!(connection_tstops, t) end end diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 6c09bde..d0b11c9 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -88,7 +88,7 @@ struct Spring{T} <: ConnectionRule end Base.zero(::Type{Spring{T}}) where {T} = Spring(zero(T)) -function ((;k)::Spring)(src::Subsystem, dst::Subsystem) +function ((;k)::Spring)(src::Subsystem, dst::Subsystem, t) # Calculate the force on subsystem `dst` due to being connected with # subsystem `src` by a spring with spring constant `k`. F = k * (src.x - dst.x) @@ -101,7 +101,7 @@ struct Coulomb{T} <: ConnectionRule end Base.zero(::Type{Coulomb{T}}) where {T} = Coulomb(zero(T)) -function ((;fac)::Coulomb)(src::Subsystem, dst::Subsystem) +function ((;fac)::Coulomb)(src::Subsystem, dst::Subsystem, t) # Calculate the Coulomb force on subsystem `dst` due to the charge of subsystem `src` F = -fac * src.q * dst.q * sign(src.x - dst.x)/(abs(src.x - dst.x))^2 # Return the `input` being sent to the `dst` subsystem