Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GraphDynamics"
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
version = "0.4.9"
version = "0.5.0"

[workspace]
projects = ["test", "scrap"]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct Spring <: ConnectionRule
k::Float64
end

function ((;k)::Spring)(src::Subsystem, dst::Subsystem)
function ((;k)::Spring)(src::Subsystem, dst::Subsystem, t)
# Calculate the force on subsystem `dst` due to being connected with
# subsystem `src` by a spring with spring constant `k`.
F = k * (src.x - dst.x)
Expand All @@ -119,7 +119,7 @@ end
``` julia
struct Coulomb <: ConnectionRule end

function (::Coulomb)(src::Subsystem, dst::Subsystem)
function (::Coulomb)(src::Subsystem, dst::Subsystem, t)
# Calculate the Coulomb force on subsystem `dst` due to the charge of subsystem `src`
F = -src.q * dst.q * sign(src.x - dst.x)/(abs(src.x - dst.x))^2
# Return the `input` being sent to the `dst` subsystem
Expand Down
25 changes: 15 additions & 10 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export


#----------------------------------------------------------
using Base: @kwdef, @propagate_inbounds
using Base: @kwdef, @propagate_inbounds, isassigned
using Base.Iterators: map as imap

using Base.Cartesian: @nexprs
Expand Down Expand Up @@ -232,11 +232,6 @@ end
"""
computed_properties(s::Subsystem) = (;)

# TODO: delete this in the next breaking release.
# Accidentally released this with computed_properies
# as the name to use
const computed_properies = computed_properties


"""
computed_properties_with_inputs(s::Subsystem)
Expand Down Expand Up @@ -382,6 +377,7 @@ for s ∈ [:continuous, :discrete]
$has_events(::Type{<:Subsystem{T}}) where {T} = $has_events(T)
$has_events(::Type{<:SubsystemStates{T}}) where {T} = $has_events(T)
$has_events(::Type{T}) where {T} = false
$has_events(::Type, ::Type, ::Type) = false

$events_require_inputs(::Subsystem{T}) where {T} = $events_require_inputs(T)
$events_require_inputs(::Type{<:Subsystem{T}}) where {T} = $events_require_inputs(T)
Expand All @@ -391,19 +387,28 @@ for s ∈ [:continuous, :discrete]
end
end


"""
event_times(::Subsystem{SysType}) = ()

add methods to this function if a subsystem type `SysType` has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time.
"""
event_times(::T) = ()
event_times(::Subsystem) = ()

add methods to this function if a subsystem or connection type has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time.
"""
event_times(::Any) = ()
event_times(::ConnType, ::Subsystem{SysSrc}, ::Subsystem{SysDst}) = ()

add methods to this function if a connection type `ConnType` has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time.
"""
event_times(::Any, ::Any, ::Any) = ()

abstract type ConnectionRule end
(c::ConnectionRule)(src, dst, t) = c(src, dst)
Base.zero(::T) where {T <: ConnectionRule} = zero(T)

struct NotConnected{CR} end
Base.getindex(::NotConnected{CR}, inds...) where {CR} = zero(CR)
Base.eltype(::NotConnected{CR}) where {CR} = CR

Base.copy(c::NotConnected) = c
struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected{CR}, AbstractMatrix{CR}}}}}
data::Tup
Expand Down
26 changes: 18 additions & 8 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ end
function maybe_sparse_enumerate_col(M::AbstractMatrix, j)
enumerate(@view(M[:, j]))
end
function maybe_sparse_enumerate_col(::NotConnected, j)
()
end


#----------------------------------------------------------
Expand Down Expand Up @@ -322,7 +325,7 @@ function _continuous_affect!(integrator,
end
end
end
end
end

#----------------------------------------------------------
# Infra. for discrete events.
Expand Down Expand Up @@ -357,10 +360,15 @@ end
@nexprs $Len i -> begin
@nexprs $Len k -> begin
M = connection_matrices[nc].data[k][i] # Same as cm[nc][k,i] but performs better when there's many types
if has_discrete_events(eltype(M))

if !(M isa NotConnected) && has_discrete_events(eltype(M),
get_tag(eltype(states_partitioned[i])),
get_tag(eltype(states_partitioned[k])))
for j ∈ eachindex(states_partitioned[i])
sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l])
discrete_event_condition(Mlj, t, sys_src, sys_dst) && return true
end
end
end
Expand Down Expand Up @@ -432,10 +440,12 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
sview_dst = @view states_partitioned[i][j]
pview_dst = @view params_partitioned[i][j]
M = connection_matrices.matrices[nc].data[k][i]
if has_discrete_events(eltype(M))
if !(M isa NotConnected) && has_discrete_events(eltype(M),
get_tag(eltype(states_partitioned[i])),
get_tag(eltype(states_partitioned[k])))
for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j)
if discrete_event_condition(Mlj, t)
sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l])
sys_src = Subsystem(states_partitioned[k][l], params_partitioned[k][l])
if discrete_event_condition(Mlj, t, sys_src, sys_dst)
sview_src = @view states_partitioned[k][l]
pview_src = @view params_partitioned[k][l]
if discrete_events_require_inputs(typeof(Mlj))
Expand Down Expand Up @@ -530,8 +540,8 @@ end
nothing
else
for j ∈ eachindex(states_partitioned[i])
@inbounds conn = M[l, j]
if !iszero(conn)
if isassigned(M, l, j)
conn = M[l, j]
@inbounds states_view_dst = @view states_partitioned[i][j]
@inbounds params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
Expand Down
21 changes: 13 additions & 8 deletions src/graph_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,15 @@ function system_wiring_rule!(g, src, dst; kwargs...)
add_connection!(g, src, dst; conn=kwargs[:conn], kwargs...)
end

@kwdef struct PartitionedGraphSystem{CM <: ConnectionMatrices, S, P, EVT, Ns, CONM, SNM, PNM, CNM, EP}
@kwdef struct PartitionedGraphSystem{CM <: ConnectionMatrices, S, P, SP, EVT, Ns, CONM, SNM, PNM, CNM, EP}
graph::Union{Nothing, GraphSystem} = nothing
flat_graph::Union{Nothing, GraphSystem} = nothing
connection_matrices::CM
states_partitioned::S
params_partitioned::P
subsystems_partitioned::SP = map(i -> map(j -> Subsystem(states_partitioned[i][j], params_partitioned[i][j]),
eachindex(states_partitioned[i], params_partitioned[i])),
eachindex(states_partitioned, params_partitioned))
tstops::EVT = Float64[]
names_partitioned::Ns
connection_namemap::CONM = make_connection_namemape(names_partitioned, connection_matrices)
Expand Down Expand Up @@ -174,7 +177,8 @@ function PartitionedGraphSystem(g::GraphSystem)

This allows for type-stable calculations involving the subsystems and their connections
===================================================================================================#
(;connection_matrices, connection_tstops, connection_namemap) = make_connection_matrices(g_flat, nodes_partitioned)
(;connection_matrices, connection_tstops, connection_namemap) = make_connection_matrices(g_flat, nodes_partitioned;
subsystems_partitioned, names_partitioned)

append!(tstops, connection_tstops)

Expand All @@ -184,6 +188,7 @@ function PartitionedGraphSystem(g::GraphSystem)
flat_graph=g_flat,
is_stochastic = any(isstochastic, node_types),
connection_matrices,
subsystems_partitioned,
states_partitioned,
params_partitioned,
tstops=unique!(tstops),
Expand All @@ -195,18 +200,18 @@ end
function make_partitioned_nodes(g_flat)
node_types = (unique ∘ imap)(typeof, nodes(g_flat))
nodes_partitioned = map(node_types) do T
if isstochastic(T)
system_is_stochastic = true
end
filter(collect(nodes(g_flat))) do sys
sys isa T
end
end
end


function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nodes(g_flat);
pred=(_) -> true,
conn_key=:conn)
conn_key=:conn,
subsystems_partitioned=map(v -> map(to_subsystem, v), nodes_partitioned),
names_partitioned=map(v -> map(x -> convert(Symbol, get_name(x)), v), nodes_partitioned))
check_no_double_connections(g_flat, conn_key)
connection_types = (imap)(connections(g_flat)) do (; src, dst, data)
if haskey(data, conn_key) && pred(data[conn_key])
Expand Down Expand Up @@ -234,11 +239,11 @@ function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nod
push!(ls, l)
push!(conns, conn)

for (prop, name) ∈ pairs(connection_property_namemap(conn, get_name(nodekl), get_name(nodeij)))
for (prop, name) ∈ pairs(connection_property_namemap(conn, names_partitioned[k][l], names_partitioned[i][j]))
connection_namemap[name] = ConnectionIndex(nc, k, i, l, j, name, prop)
end

for t ∈ event_times(conn)
for t ∈ event_times(conn, subsystems_partitioned[k][l], subsystems_partitioned[i][j])
push!(connection_tstops, t)
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/particle_osc_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct Spring{T} <: ConnectionRule
end
Base.zero(::Type{Spring{T}}) where {T} = Spring(zero(T))

function ((;k)::Spring)(src::Subsystem, dst::Subsystem)
function ((;k)::Spring)(src::Subsystem, dst::Subsystem, t)
# Calculate the force on subsystem `dst` due to being connected with
# subsystem `src` by a spring with spring constant `k`.
F = k * (src.x - dst.x)
Expand All @@ -101,7 +101,7 @@ struct Coulomb{T} <: ConnectionRule
end
Base.zero(::Type{Coulomb{T}}) where {T} = Coulomb(zero(T))

function ((;fac)::Coulomb)(src::Subsystem, dst::Subsystem)
function ((;fac)::Coulomb)(src::Subsystem, dst::Subsystem, t)
# Calculate the Coulomb force on subsystem `dst` due to the charge of subsystem `src`
F = -fac * src.q * dst.q * sign(src.x - dst.x)/(abs(src.x - dst.x))^2
# Return the `input` being sent to the `dst` subsystem
Expand Down