diff --git a/bloat-check/src/bin/bloat-check.rs b/bloat-check/src/bin/bloat-check.rs index eb8a7a5f9..fe427302e 100644 --- a/bloat-check/src/bin/bloat-check.rs +++ b/bloat-check/src/bin/bloat-check.rs @@ -59,7 +59,8 @@ use rs_matter::crypto::backend::rustcrypto::RustCrypto; use rs_matter::crypto::{Crypto, RngCore, WeakTestOnlyRand}; use rs_matter::dm::clusters::desc::{self, ClusterHandler as _, DescHandler}; use rs_matter::dm::clusters::net_comm::{ - NetCtl, NetCtlError, NetCtlStatus, NetworkScanInfo, NetworkType, Networks, WirelessCreds, + NetCtl, NetCtlError, NetCtlStatus, NetworkScanInfo, NetworkType, NetworksAccess, + SharedNetworks, WirelessCreds, }; use rs_matter::dm::clusters::on_off::NoLevelControl; use rs_matter::dm::clusters::on_off::{self, test::TestOnOffDeviceLogic, OnOffHooks}; @@ -80,6 +81,7 @@ use rs_matter::dm::{Async, DataModel, Dataver, EmptyHandler, Endpoint, EpClMatch use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; +use rs_matter::persist::{DummyKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::tlv::Nullable; @@ -157,7 +159,7 @@ struct MatterStack<'a> { buffers: PooledBuffers<10, IMBuffer>, subscriptions: Subscriptions<{ DEFAULT_MAX_SUBSCRIPTIONS }>, events: Events, - networks: WifiNetworks<3>, + networks: SharedNetworks>, net_ctl_state: NetCtlStateMutex, btp: Btp, wireless_mgr_buffer: MaybeUninit<[u8; MAX_CREDS_SIZE]>, @@ -179,7 +181,7 @@ impl<'a> MatterStack<'a> { buffers <- PooledBuffers::init(0), subscriptions <- Subscriptions::init(), events <- Events::init(dummy_epoch), - networks <- WifiNetworks::init(), + networks <- SharedNetworks::init(WifiNetworks::init()), net_ctl_state <- NetCtlState::init_with_mutex(), btp <- Btp::init(), wireless_mgr_buffer: MaybeUninit::zeroed(), @@ -191,8 +193,9 @@ impl<'a> MatterStack<'a> { // Fully spelled-out types for everything which is passed 
down as arguments to `embassy-executor` tasks // Necessary, because `embassy-executor` doesn't grok generics +type AppNetworks = SharedNetworks>; type AppNetCtl<'a> = NetCtlWithStatusImpl<'a, FakeWifi>; -type AppWirelessMgr<'a> = WirelessMgr<'a, &'a WifiNetworks<3>, &'a AppNetCtl<'a>>; +type AppWirelessMgr<'a> = WirelessMgr<'a, &'a AppNetworks, &'a AppNetCtl<'a>>; type AppTransport<'a> = ChainedNetwork bool>; type AppHandler<'a> = handler_chain_type!( EpClMatcher => on_off::HandlerAsyncAdaptor>, @@ -200,7 +203,8 @@ type AppHandler<'a> = handler_chain_type!( | EmptyHandler ); type AppCrypto = RustCrypto<'static, WeakTestOnlyRand>; -type AppDmHandler<'a> = WifiHandler<'a, &'a AppNetCtl<'a>, SysHandler<'a, AppHandler<'a>>>; +type AppDmHandler<'a> = + WifiHandler<'a, &'a AppNetworks, &'a AppNetCtl<'a>, SysHandler<'a, AppHandler<'a>>>; type AppDataModel<'a> = DataModel< 'a, DEFAULT_MAX_SUBSCRIPTIONS, @@ -208,6 +212,7 @@ type AppDataModel<'a> = DataModel< &'a AppCrypto, PooledBuffers<10, IMBuffer>, (Node<'a>, &'a AppDmHandler<'a>), + SharedKvBlobStore, >; type AppResponder<'d, 'a> = DefaultResponder< 'd, @@ -217,6 +222,7 @@ type AppResponder<'d, 'a> = DefaultResponder< &'a AppCrypto, PooledBuffers<10, IMBuffer>, (Node<'a>, &'a AppDmHandler<'a>), + SharedKvBlobStore, >; #[cfg_attr(target_os = "none", main)] @@ -324,6 +330,9 @@ fn main() -> ! { let mut rand = unwrap!(crypto.weak_rand()); + let kv_buf = unsafe { stack.psm_buffer.assume_init_mut() }.as_mut_slice(); + let kv = SharedKvBlobStore::new(DummyKvBlobStore, kv_buf); + // A Wireless handler with a sample app cluster (on-off) let handler = mk_static!( AppDmHandler, @@ -334,8 +343,8 @@ fn main() -> ! { 1, TestOnOffDeviceLogic::new(true), ), - net_ctl, &stack.networks, + net_ctl, ) ); @@ -351,6 +360,7 @@ fn main() -> ! { &stack.subscriptions, Some(&stack.events), (NODE, handler), + kv, ) ); @@ -636,20 +646,21 @@ const NODE: Node<'static> = Node { /// The Data Model handler for our Matter device. 
/// The handler is the root endpoint 0 handler plus the on-off handler and its descriptor. -fn dm_handler<'a, N>( +fn dm_handler<'a, N, T>( mut rand: impl RngCore + Copy, on_off: on_off::OnOffHandler<'a, TestOnOffDeviceLogic, NoLevelControl>, - net_ctl: &'a N, - networks: &'a dyn Networks, -) -> WifiHandler<'a, &'a N, SysHandler<'a, AppHandler<'a>>> + networks: N, + net_ctl: &'a T, +) -> WifiHandler<'a, N, &'a T, SysHandler<'a, AppHandler<'a>>> where - N: NetCtl + NetCtlStatus + WifiDiag, + N: NetworksAccess, + T: NetCtl + NetCtlStatus + WifiDiag, { endpoints::with_wifi( &(), &(), - net_ctl, networks, + net_ctl, rand, endpoints::with_sys( &true, diff --git a/examples/src/bin/bridge.rs b/examples/src/bin/bridge.rs index 1c22d2156..95b133082 100644 --- a/examples/src/bin/bridge.rs +++ b/examples/src/bin/bridge.rs @@ -23,9 +23,10 @@ use core::pin::pin; use std::net::UdpSocket; -use embassy_futures::select::{select, select4}; +use embassy_futures::select::select4; use rand::RngCore; + use rs_matter::crypto::{default_crypto, Crypto}; use rs_matter::dm::clusters::desc::{self, ClusterHandler as _}; use rs_matter::dm::clusters::groups::{self, ClusterHandler as _}; @@ -45,7 +46,7 @@ use rs_matter::dm::{ use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::{DirKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::tlv::{TLVBuilderParent, Utf8StrBuilder}; @@ -67,11 +68,16 @@ fn main() -> Result<(), Error> { ); // Create the Matter object - let matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); + let mut matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let mut kv_buf = [0; 4096]; + let mut kv = 
DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, &mut kv_buf))?; + // Create the transport buffers let buffers = PooledBuffers::<10, _>::new(0); @@ -106,6 +112,7 @@ fn main() -> Result<(), Error> { &subscriptions, Some(&events), dm_handler(rand, &on_off_handler_ep2, &on_off_handler_ep3), + SharedKvBlobStore::new(kv, kv_buf.as_mut_slice()), ); // Create a default responder capable of handling up to 3 subscriptions @@ -126,12 +133,6 @@ fn main() -> Result<(), Error> { let mut mdns = pin!(mdns::run_mdns(&matter, &crypto, dm.change_notify())); let mut transport = pin!(matter.run(&crypto, &socket, &socket, &socket)); - // Create, load and run the persister - let mut psm: Psm<4096> = Psm::new(); - let path = std::env::temp_dir().join("rs-matter"); - - psm.load(&path, &matter, NO_NETWORKS, Some(&events))?; - if !matter.is_commissioned() { // If the device is not commissioned yet, print the QR text and code to the console // and enable basic commissioning @@ -142,18 +143,11 @@ fn main() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, &crypto, dm.change_notify())?; } - let mut persist = pin!(psm.run(&path, &matter, NO_NETWORKS, Some(&events))); - // Combine all async tasks in a single one - let all = select4( - &mut transport, - &mut mdns, - &mut persist, - select(&mut respond, &mut dm_job).coalesce(), - ); + let all = select4(&mut transport, &mut mdns, &mut respond, &mut dm_job).coalesce(); // Run with a simple `block_on`. Any local executor would do. - futures_lite::future::block_on(all.coalesce()) + futures_lite::future::block_on(all) } /// The Node meta-data describing our Matter device. 
diff --git a/examples/src/bin/chip_tool_tests.rs b/examples/src/bin/chip_tool_tests.rs index 03093e707..ee83890ae 100644 --- a/examples/src/bin/chip_tool_tests.rs +++ b/examples/src/bin/chip_tool_tests.rs @@ -21,11 +21,10 @@ use core::pin::pin; use std::net::UdpSocket; -use std::path::PathBuf; use async_signal::{Signal, Signals}; -use embassy_futures::select::{select3, select4}; +use embassy_futures::select::select3; use futures_lite::StreamExt; @@ -58,7 +57,7 @@ use rs_matter::dm::{ use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::{FileKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::transport::MATTER_SOCKET_BIND_ADDR; @@ -73,16 +72,6 @@ use static_cell::StaticCell; #[path = "../common/mdns.rs"] mod mdns; -/// The `chip-tool` tests expect the persistent storage location -/// to be `/tmp/chip_kvs`. -/// -/// Moreover, this _must_ be a file rather than a directory. -/// -/// While there seem to be some facilities to change that in some of the Python scripts, -/// these facilities are simply not exposed at the top level test suite Python runner. -/// TODO: Open a bug for that (and for the single-file expectation) in the `connectedhomeip` repo. -const PERSIST_FILE_NAME: &str = "/tmp/chip_kvs"; - // Statically allocate in BSS the bigger objects // `rs-matter` supports efficient initialization of BSS objects (with `init`) // as well as just allocating the objects on-stack or on the heap. 
@@ -90,7 +79,7 @@ static MATTER: StaticCell = StaticCell::new(); static BUFFERS: StaticCell> = StaticCell::new(); static SUBSCRIPTIONS: StaticCell = StaticCell::new(); static EVENTS: StaticCell = StaticCell::new(); -static PSM: StaticCell> = StaticCell::new(); +static KV_BUF: StaticCell<[u8; 4096]> = StaticCell::new(); static UNIT_TESTING_DATA: StaticCell> = StaticCell::new(); fn main() -> Result<(), Error> { @@ -126,6 +115,11 @@ fn main() -> Result<(), Error> { // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let kv_buf = KV_BUF.uninit().init_zeroed().as_mut_slice(); + let mut kv = FileKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, kv_buf))?; + // Create the transport buffers let buffers = BUFFERS.uninit().init_with(PooledBuffers::init(0)); @@ -176,6 +170,7 @@ fn main() -> Result<(), Error> { &on_off_handler_1, &on_off_handler_2, ), + SharedKvBlobStore::new(kv, kv_buf), ); // Create a default responder capable of handling up to 3 subscriptions @@ -207,18 +202,6 @@ fn main() -> Result<(), Error> { let mut mdns = pin!(mdns::run_mdns(matter, &crypto, dm.change_notify())); let mut transport = pin!(matter.run(&crypto, &socket, &socket, &socket)); - // Create, load and run the persister - let psm = PSM.uninit().init_with(Psm::init()); - let path = PathBuf::from(PERSIST_FILE_NAME); - - info!( - "Persist memory: Persist (BSS)={}B, Persist fut (stack)={}B", - core::mem::size_of::>(), - core::mem::size_of_val(&psm.run(&path, matter, NO_NETWORKS, Some(events))) - ); - - psm.load(&path, matter, NO_NETWORKS, Some(events))?; - // We need to always print the QR text, because the test runner expects it to be printed // even if the device is already commissioned matter.print_standard_qr_text(DiscoveryCapabilities::IP)?; @@ -232,8 +215,6 @@ fn main() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, &crypto, dm.change_notify())?; } - let mut persist = pin!(psm.run(&path, 
matter, NO_NETWORKS, Some(events))); - // Listen to SIGTERM because at the end of the test we'll receive it let mut term_signal = Signals::new([Signal::Term])?; let mut term = pin!(async { @@ -242,10 +223,9 @@ fn main() -> Result<(), Error> { }); // Combine all async tasks in a single one - let all = select4( + let all = select3( &mut transport, &mut mdns, - &mut persist, select3(&mut respond, &mut dm_job, &mut term).coalesce(), ); diff --git a/examples/src/bin/dimmable_light.rs b/examples/src/bin/dimmable_light.rs index dab212b8d..7deef6187 100644 --- a/examples/src/bin/dimmable_light.rs +++ b/examples/src/bin/dimmable_light.rs @@ -26,7 +26,7 @@ use std::io::{Read, Write}; use std::net::UdpSocket; use std::path::PathBuf; -use embassy_futures::select::{select3, select4}; +use embassy_futures::select::select3; use async_signal::{Signal, Signals}; use log::{error, info, trace}; @@ -58,7 +58,7 @@ use rs_matter::dm::{ use rs_matter::error::{Error, ErrorCode}; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::SharedKvBlobStore; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::tlv::Nullable; @@ -80,10 +80,7 @@ static MATTER: StaticCell = StaticCell::new(); static BUFFERS: StaticCell> = StaticCell::new(); static SUBSCRIPTIONS: StaticCell = StaticCell::new(); static EVENTS: StaticCell = StaticCell::new(); -static PSM: StaticCell> = StaticCell::new(); - -#[cfg(feature = "chip-test")] -const PERSIST_FILE_NAME: &str = "/tmp/chip_kvs"; +static KV_BUF: StaticCell<[u8; 4096]> = StaticCell::new(); fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() @@ -135,6 +132,14 @@ fn run() -> Result<(), Error> { // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let kv_buf = KV_BUF.uninit().init_zeroed().as_mut_slice(); + #[cfg(feature = "chip-test")] + let mut kv = 
rs_matter::persist::FileKvBlobStore::new_default(); + #[cfg(not(feature = "chip-test"))] + let mut kv = rs_matter::persist::DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, kv_buf))?; + // Create the transport buffers let buffers = BUFFERS.uninit().init_with(PooledBuffers::init(0)); @@ -181,6 +186,7 @@ fn run() -> Result<(), Error> { subscriptions, Some(events), dm_handler(rand, &on_off_handler, &level_control_handler), + SharedKvBlobStore::new(kv, kv_buf), ); // Create a default responder capable of handling up to 3 subscriptions @@ -211,22 +217,6 @@ fn run() -> Result<(), Error> { let mut mdns = pin!(mdns::run_mdns(matter, &crypto, dm.change_notify())); let mut transport = pin!(matter.run(&crypto, &socket, &socket, &socket)); - // Create, load and run the persister - let psm = PSM.uninit().init_with(Psm::init()); - #[cfg(not(feature = "chip-test"))] - let path = std::env::temp_dir().join("rs-matter"); - #[cfg(feature = "chip-test")] - let path = PathBuf::from(PERSIST_FILE_NAME); - - info!( - "Persist memory: Persist (BSS)={}B, Persist fut (stack)={}B, Persist path={}", - core::mem::size_of::>(), - core::mem::size_of_val(&psm.run(&path, matter, NO_NETWORKS, Some(events))), - path.as_path().to_str().unwrap_or("none") - ); - - psm.load(&path, matter, NO_NETWORKS, Some(events))?; - // We need to always print the QR text, because the test runner expects it to be printed // even if the device is already commissioned matter.print_standard_qr_text(DiscoveryCapabilities::IP)?; @@ -240,8 +230,6 @@ fn run() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, &crypto, dm.change_notify())?; } - let mut persist = pin!(psm.run(&path, matter, NO_NETWORKS, Some(events))); - // Listen to SIGTERM because at the end of the test we'll receive it let mut term_signal = Signals::new([Signal::Term])?; let mut term = pin!(async { @@ -250,10 +238,9 @@ fn run() -> Result<(), Error> { }); // Combine all async tasks in 
a single one - let all = select4( + let all = select3( &mut transport, &mut mdns, - &mut persist, select3(&mut respond, &mut dm_job, &mut term).coalesce(), ); diff --git a/examples/src/bin/media_player.rs b/examples/src/bin/media_player.rs index 074803f52..350aa1827 100644 --- a/examples/src/bin/media_player.rs +++ b/examples/src/bin/media_player.rs @@ -27,7 +27,7 @@ use core::pin::pin; use std::net::UdpSocket; -use embassy_futures::select::{select, select4}; +use embassy_futures::select::select4; use log::info; @@ -69,7 +69,7 @@ use rs_matter::dm::{ use rs_matter::error::{Error, ErrorCode}; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::{DirKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::tlv::{TLVBuilderParent, Utf8StrArrayBuilder, Utf8StrBuilder}; @@ -87,11 +87,16 @@ fn main() -> Result<(), Error> { ); // Create the Matter object - let matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); + let mut matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let mut kv_buf = [0; 4096]; + let mut kv = DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, &mut kv_buf))?; + // Create the transport buffers let buffers = PooledBuffers::<10, _>::new(0); @@ -121,6 +126,7 @@ fn main() -> Result<(), Error> { &subscriptions, Some(&events), dm_handler(rand, &on_off_handler), + SharedKvBlobStore::new(kv, kv_buf.as_mut_slice()), ); // Create a default responder capable of handling up to 3 subscriptions @@ -141,12 +147,6 @@ fn main() -> Result<(), Error> { let mut mdns = pin!(mdns::run_mdns(&matter, &crypto, dm.change_notify())); let mut transport = pin!(matter.run(&crypto, &socket, &socket, 
&socket)); - // Create, load and run the persister - let mut psm: Psm<4096> = Psm::new(); - let path = std::env::temp_dir().join("rs-matter"); - - psm.load(&path, &matter, NO_NETWORKS, Some(&events))?; - if !matter.is_commissioned() { // If the device is not commissioned yet, print the QR text and code to the console // and enable basic commissioning @@ -157,18 +157,11 @@ fn main() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, &crypto, dm.change_notify())?; } - let mut persist = pin!(psm.run(&path, &matter, NO_NETWORKS, Some(&events))); - // Combine all async tasks in a single one - let all = select4( - &mut transport, - &mut mdns, - &mut persist, - select(&mut respond, &mut dm_job).coalesce(), - ); + let all = select4(&mut transport, &mut mdns, &mut respond, &mut dm_job).coalesce(); // Run with a simple `block_on`. Any local executor would do. - futures_lite::future::block_on(all.coalesce()) + futures_lite::future::block_on(all) } /// The Node meta-data describing our Matter device. 
diff --git a/examples/src/bin/onoff_light.rs b/examples/src/bin/onoff_light.rs index bda294fbd..3f189e45c 100644 --- a/examples/src/bin/onoff_light.rs +++ b/examples/src/bin/onoff_light.rs @@ -21,11 +21,12 @@ use core::pin::pin; use std::net::UdpSocket; -use embassy_futures::select::{select, select4}; +use embassy_futures::select::select4; use log::info; use rand::RngCore; + use rs_matter::crypto::{default_crypto, Crypto}; use rs_matter::dm::clusters::desc::{self, ClusterHandler as _}; use rs_matter::dm::clusters::groups::{self, ClusterHandler as _}; @@ -46,7 +47,7 @@ use rs_matter::dm::{ use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::{DirKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::transport::MATTER_SOCKET_BIND_ADDR; @@ -66,7 +67,7 @@ mod mdns; static MATTER: StaticCell = StaticCell::new(); static BUFFERS: StaticCell> = StaticCell::new(); static SUBSCRIPTIONS: StaticCell = StaticCell::new(); -static PSM: StaticCell> = StaticCell::new(); +static KV_BUF: StaticCell<[u8; 4096]> = StaticCell::new(); fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() @@ -107,6 +108,11 @@ fn run() -> Result<(), Error> { // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let kv_buf = KV_BUF.uninit().init_zeroed().as_mut_slice(); + let mut kv = DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, kv_buf))?; + // Create the transport buffers let buffers = BUFFERS.uninit().init_with(PooledBuffers::init(0)); @@ -135,6 +141,7 @@ fn run() -> Result<(), Error> { subscriptions, NO_EVENTS, dm_handler(rand, &on_off_handler), + SharedKvBlobStore::new(kv, kv_buf), ); // Create a default responder capable of handling up to 3 subscriptions @@ -166,18 +173,6 @@ fn run() -> 
Result<(), Error> { let mut mdns = pin!(mdns::run_mdns(matter, &crypto, dm.change_notify())); let mut transport = pin!(matter.run(&crypto, &socket, &socket, &socket)); - // Create, load and run the persister - let psm = PSM.uninit().init_with(Psm::init()); - let path = std::env::temp_dir().join("rs-matter"); - - info!( - "Persist memory: Persist (BSS)={}B, Persist fut (stack)={}B", - core::mem::size_of::>(), - core::mem::size_of_val(&psm.run(&path, matter, NO_NETWORKS, NO_EVENTS)) - ); - - psm.load(&path, matter, NO_NETWORKS, NO_EVENTS)?; - if !matter.is_commissioned() { // If the device is not commissioned yet, print the QR text and code to the console // and enable basic commissioning @@ -188,18 +183,11 @@ fn run() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, &crypto, dm.change_notify())?; } - let mut persist = pin!(psm.run(&path, matter, NO_NETWORKS, NO_EVENTS)); - // Combine all async tasks in a single one - let all = select4( - &mut transport, - &mut mdns, - &mut persist, - select(&mut respond, &mut dm_job).coalesce(), - ); + let all = select4(&mut transport, &mut mdns, &mut respond, &mut dm_job).coalesce(); // Run with a simple `block_on`. Any local executor would do. - futures_lite::future::block_on(all.coalesce()) + futures_lite::future::block_on(all) } /// The Node meta-data describing our Matter device. 
diff --git a/examples/src/bin/onoff_light_bt.rs b/examples/src/bin/onoff_light_bt.rs index 680400c56..a013bb3b1 100644 --- a/examples/src/bin/onoff_light_bt.rs +++ b/examples/src/bin/onoff_light_bt.rs @@ -44,7 +44,9 @@ use rs_matter::crypto::{default_crypto, Crypto}; use rs_matter::dm::clusters::desc::{self, ClusterHandler as _}; use rs_matter::dm::clusters::groups::{self, ClusterHandler as _}; use rs_matter::dm::clusters::level_control::LevelControlHooks; -use rs_matter::dm::clusters::net_comm::{NetCtl, NetCtlStatus, NetworkType, Networks}; +use rs_matter::dm::clusters::net_comm::{ + NetCtl, NetCtlStatus, NetworkType, NetworksAccess, SharedNetworks, +}; use rs_matter::dm::clusters::on_off::{self, test::TestOnOffDeviceLogic, OnOffHooks}; use rs_matter::dm::clusters::wifi_diag::WifiDiag; use rs_matter::dm::devices::test::{DAC_PRIVKEY, TEST_DEV_ATT, TEST_DEV_COMM, TEST_DEV_DET}; @@ -61,7 +63,7 @@ use rs_matter::dm::{ use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::Psm; +use rs_matter::persist::{DirKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; #[cfg(target_os = "linux")] @@ -122,11 +124,20 @@ fn main() -> Result<(), Error> { fn run(connection: &Connection, net_ctl: N) -> Result<(), Error> { // Create the Matter object - let matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); + let mut matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); // Need to call this once matter.initialize_transport_buffers()?; + // A storage for the Wifi networks + let mut networks = WifiNetworks::<3>::new(); + + // Persistence + let mut kv_buf = [0; 4096]; + let mut kv = DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, &mut kv_buf))?; + futures_lite::future::block_on(networks.load_persist(&mut kv, &mut 
kv_buf))?; + // Create the transport buffers let buffers = PooledBuffers::<10, _>::new(0); @@ -148,9 +159,6 @@ fn run(connection: &Connection, net_ctl: N) -> Result<(), TestOnOffDeviceLogic::new(true), ); - // A storage for the Wifi networks - let networks = WifiNetworks::<3>::new(); - // The network controller let net_ctl_state = NetCtlState::new_with_mutex(); @@ -163,7 +171,13 @@ fn run(connection: &Connection, net_ctl: N) -> Result<(), &buffers, &subscriptions, Some(&events), - dm_handler(rand, &on_off_handler, &net_ctl, &networks), + dm_handler( + rand, + &on_off_handler, + SharedNetworks::new(networks), + &net_ctl, + ), + SharedKvBlobStore::new(kv, kv_buf.as_mut_slice()), ); // Create a default responder capable of handling up to 3 subscriptions @@ -177,14 +191,6 @@ fn run(connection: &Connection, net_ctl: N) -> Result<(), // Run the background job of the data model let mut dm_job = pin!(dm.run()); - // Create, load and run the persister - let mut psm: Psm<4096> = Psm::new(); - let path = std::env::temp_dir().join("rs-matter"); - - psm.load(&path, &matter, Some(&networks), Some(&events))?; - - let mut persist = pin!(psm.run(&path, &matter, Some(&networks), Some(&events))); - // Create and run the mDNS responder let mut mdns = pin!(mdns::run_mdns(&matter, &crypto, dm.change_notify())); @@ -230,7 +236,7 @@ fn run(connection: &Connection, net_ctl: N) -> Result<(), let all = select4( &mut transport, &mut bluetooth, - select(&mut wifi_prov_task, &mut persist).coalesce(), + &mut wifi_prov_task, select(&mut respond, &mut dm_job).coalesce(), ); @@ -248,15 +254,10 @@ fn run(connection: &Connection, net_ctl: N) -> Result<(), let mut transport = pin!(matter.run(&crypto, &udp, &udp, &udp)); // Combine all async tasks in a single one - let all = select4( - &mut transport, - &mut mdns, - &mut persist, - select(&mut respond, &mut dm_job).coalesce(), - ); + let all = select4(&mut transport, &mut mdns, &mut respond, &mut dm_job).coalesce(); // Run with a simple `block_on`. 
Any local executor would do. - futures_lite::future::block_on(all.coalesce()) + futures_lite::future::block_on(all) } /// The Node meta-data describing our Matter device. @@ -278,22 +279,23 @@ const NODE: Node<'static> = Node { /// The Data Model handler + meta-data for our Matter device. /// The handler is the root endpoint 0 handler plus the on-off handler and its descriptor. -fn dm_handler<'a, OH: OnOffHooks, LH: LevelControlHooks, N>( +fn dm_handler<'a, OH: OnOffHooks, LH: LevelControlHooks, N, T>( mut rand: impl RngCore + Copy, on_off: &'a on_off::OnOffHandler<'a, OH, LH>, - net_ctl: &'a N, - networks: &'a dyn Networks, + networks: N, + net_ctl: &'a T, ) -> impl AsyncMetadata + AsyncHandler + 'a where - N: NetCtl + NetCtlStatus + WifiDiag, + N: NetworksAccess + 'a, + T: NetCtl + NetCtlStatus + WifiDiag, { ( NODE, endpoints::with_wifi( &(), &UnixNetifs, - net_ctl, networks, + net_ctl, rand, endpoints::with_sys( &true, diff --git a/examples/src/bin/onoff_light_work_stealing.rs b/examples/src/bin/onoff_light_work_stealing.rs index c1697a1d2..4b5615dbb 100644 --- a/examples/src/bin/onoff_light_work_stealing.rs +++ b/examples/src/bin/onoff_light_work_stealing.rs @@ -44,7 +44,7 @@ use rs_matter::dm::{Async, DataModel, Dataver, EmptyHandler, Endpoint, EpClMatch use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::{DirKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::transport::MATTER_SOCKET_BIND_ADDR; @@ -70,8 +70,8 @@ type AppDmHandler<'a> = EthHandler<'a, SysHandler<'a, AppHandler<'a>>>; static MATTER: StaticCell = StaticCell::new(); static BUFFERS: StaticCell> = StaticCell::new(); static SUBSCRIPTIONS: StaticCell = StaticCell::new(); -static PSM: StaticCell> = StaticCell::new(); static CRYPTO: StaticCell> = StaticCell::new(); +static 
KV_BUF: StaticCell<[u8; 4096]> = StaticCell::new(); fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() @@ -101,7 +101,7 @@ fn run() -> Result<(), Error> { core::mem::size_of::() ); - let matter = &*MATTER.uninit().init_with(Matter::init( + let matter = MATTER.uninit().init_with(Matter::init( &TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, @@ -112,6 +112,11 @@ fn run() -> Result<(), Error> { // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let kv_buf = KV_BUF.uninit().init_zeroed().as_mut_slice(); + let mut kv = DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, kv_buf))?; + // Create the transport buffers let buffers = &*BUFFERS.uninit().init_with(PooledBuffers::init(0)); @@ -140,6 +145,7 @@ fn run() -> Result<(), Error> { subscriptions, NO_EVENTS, (NODE, dm_handler(rand, on_off_handler)), + SharedKvBlobStore::new(kv, kv_buf), ); // Create a default responder capable of handling up to 3 subscriptions @@ -173,18 +179,6 @@ fn run() -> Result<(), Error> { let mdns = mdns::run_mdns(matter, crypto, dm.change_notify()); let transport = matter.run(crypto, &socket, &socket, &socket); - // Create, load and run the persister - let psm = PSM.uninit().init_with(Psm::init()); - let path = std::env::temp_dir().join("rs-matter"); - - info!( - "Persist memory: Persist (BSS)={}B, Persist fut (stack)={}B", - core::mem::size_of::>(), - core::mem::size_of_val(&psm.run(&path, matter, NO_NETWORKS, NO_EVENTS)) - ); - - psm.load(&path, matter, NO_NETWORKS, NO_EVENTS)?; - if !matter.is_commissioned() { // If the device is not commissioned yet, print the QR text and code to the console // and enable basic commissioning @@ -195,13 +189,10 @@ fn run() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, crypto, dm.change_notify())?; } - let persist = psm.run(&path, matter, NO_NETWORKS, NO_EVENTS); - let executor = async_executor::Executor::new(); 
executor.spawn(transport).detach(); executor.spawn(mdns).detach(); - executor.spawn(persist).detach(); // NOTE: Commented out because compiling this line blocks forever //executor.spawn(dm_job).detach(); diff --git a/examples/src/bin/speaker.rs b/examples/src/bin/speaker.rs index b8920aa9a..b5ee0a4f8 100644 --- a/examples/src/bin/speaker.rs +++ b/examples/src/bin/speaker.rs @@ -22,7 +22,7 @@ use core::pin::pin; use std::net::UdpSocket; -use embassy_futures::select::{select, select4}; +use embassy_futures::select::select4; use rand::RngCore; use rs_matter::crypto::{default_crypto, Crypto}; @@ -46,7 +46,7 @@ use rs_matter::dm::{ use rs_matter::error::Error; use rs_matter::pairing::qr::QrTextType; use rs_matter::pairing::DiscoveryCapabilities; -use rs_matter::persist::{Psm, NO_NETWORKS}; +use rs_matter::persist::{DirKvBlobStore, SharedKvBlobStore}; use rs_matter::respond::DefaultResponder; use rs_matter::sc::pase::MAX_COMM_WINDOW_TIMEOUT_SECS; use rs_matter::tlv::Nullable; @@ -64,11 +64,16 @@ fn main() -> Result<(), Error> { ); // Create the Matter object - let matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); + let mut matter = Matter::new_default(&TEST_DEV_DET, TEST_DEV_COMM, &TEST_DEV_ATT, MATTER_PORT); // Need to call this once matter.initialize_transport_buffers()?; + // Persistence + let mut kv_buf = [0; 4096]; + let mut kv = DirKvBlobStore::new_default(); + futures_lite::future::block_on(matter.load_persist(&mut kv, &mut kv_buf))?; + // Create the transport buffers let buffers = PooledBuffers::<10, _>::new(0); @@ -117,6 +122,7 @@ fn main() -> Result<(), Error> { &subscriptions, Some(&events), dm_handler(rand, &on_off_handler, &level_control_handler), + SharedKvBlobStore::new(kv, kv_buf.as_mut_slice()), ); // Create a default responder capable of handling up to 3 subscriptions @@ -137,12 +143,6 @@ fn main() -> Result<(), Error> { let mut mdns = pin!(mdns::run_mdns(&matter, &crypto, dm.change_notify())); let mut transport = 
pin!(matter.run(&crypto, &socket, &socket, &socket)); - // Create, load and run the persister - let mut psm: Psm<4096> = Psm::new(); - let path = std::env::temp_dir().join("rs-matter"); - - psm.load(&path, &matter, NO_NETWORKS, Some(&events))?; - if !matter.is_commissioned() { // If the device is not commissioned yet, print the QR text and code to the console // and enable basic commissioning @@ -153,18 +153,11 @@ fn main() -> Result<(), Error> { matter.open_basic_comm_window(MAX_COMM_WINDOW_TIMEOUT_SECS, &crypto, dm.change_notify())?; } - let mut persist = pin!(psm.run(&path, &matter, NO_NETWORKS, Some(&events))); - // Combine all async tasks in a single one - let all = select4( - &mut transport, - &mut mdns, - &mut persist, - select(&mut respond, &mut dm_job).coalesce(), - ); + let all = select4(&mut transport, &mut mdns, &mut respond, &mut dm_job).coalesce(); // Run with a simple `block_on`. Any local executor would do. - futures_lite::future::block_on(all.coalesce()) + futures_lite::future::block_on(all) } /// The Node meta-data describing our Matter device. 
diff --git a/examples/src/common/mdns.rs b/examples/src/common/mdns.rs index fd534b998..4ab5aaae4 100644 --- a/examples/src/common/mdns.rs +++ b/examples/src/common/mdns.rs @@ -31,12 +31,12 @@ pub async fn run_mdns( ) -> Result<(), Error> { #[cfg(feature = "astro-dnssd")] rs_matter::transport::network::mdns::astro::AstroMdnsResponder::new(matter) - .run(crypto, notify) + .run(|endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id)) .await?; #[cfg(all(feature = "zeroconf", not(feature = "astro-dnssd")))] rs_matter::transport::network::mdns::zeroconf::ZeroconfMdnsResponder::new(matter) - .run(crypto, notify) + .run(|endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id)) .await?; #[cfg(all( @@ -46,8 +46,7 @@ pub async fn run_mdns( rs_matter::transport::network::mdns::resolve::ResolveMdnsResponder::new(matter) .run( &rs_matter::utils::zbus::Connection::system().await.unwrap(), - crypto, - notify, + |endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id), ) .await?; @@ -58,8 +57,7 @@ pub async fn run_mdns( rs_matter::transport::network::mdns::avahi::AvahiMdnsResponder::new(matter) .run( &rs_matter::utils::zbus::Connection::system().await.unwrap(), - crypto, - notify, + |endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id), ) .await?; diff --git a/rs-matter/Cargo.toml b/rs-matter/Cargo.toml index 8b8f4ee4c..0e1062548 100644 --- a/rs-matter/Cargo.toml +++ b/rs-matter/Cargo.toml @@ -28,7 +28,7 @@ max-fabrics-3 = [] # default max-fabrics-2 = [] max-fabrics-1 = [] -# Number of fabrics +# Number of groups per fabrics max-groups-per-fabric-32 = [] max-groups-per-fabric-16 = [] max-groups-per-fabric-12 = [] @@ -38,11 +38,13 @@ max-groups-per-fabric-6 = [] max-groups-per-fabric-5 = [] max-groups-per-fabric-4 = [] +# Number of group keys per fabric max-group-keys-per-fabric-5 = [] max-group-keys-per-fabric-4 = [] max-group-keys-per-fabric-3 = [] max-group-keys-per-fabric-2 = [] +# Number of endpoints that can 
be members of a group (per fabric) max-group-endpoints-per-fabric-5 = [] max-group-endpoints-per-fabric-4 = [] max-group-endpoints-per-fabric-3 = [] @@ -114,8 +116,8 @@ max-btp-sessions-1 = [] # default # General astro-dnssd = ["os", "dep:astro-dnssd"] zeroconf = ["os", "dep:zeroconf"] -zbus = ["dep:zbus", "os", "futures-lite", "libc", "uuid", "async-io", "async-channel"] -os = ["std", "backtrace", "critical-section/std", "embassy-sync/std", "embassy-time/std"] +zbus = ["dep:zbus", "os", "futures-lite", "libc", "uuid", "async-channel"] +os = ["std", "backtrace", "async-io", "critical-section/std", "embassy-sync/std", "embassy-time/std"] std = ["alloc", "rand"] backtrace = [] alloc = ["defmt?/alloc"] diff --git a/rs-matter/src/acl.rs b/rs-matter/src/acl.rs index 5e993695e..6d49597d3 100644 --- a/rs-matter/src/acl.rs +++ b/rs-matter/src/acl.rs @@ -415,7 +415,8 @@ impl<'a> Accessor<'a> { }; fabric - .group_get(group_id) + .groups() + .get(group_id) .is_some_and(|e| e.endpoints.contains(&endpoint_id)) }) } @@ -846,13 +847,11 @@ pub(crate) mod tests { } fn add_acl(matter: &Matter<'_>, fab_idx: NonZeroU8, entry: AclEntry) -> Result { - matter.with_state(|state| state.fabrics.acl_add(fab_idx, entry)) + matter.with_state(|state| state.fabrics.fabric_mut(fab_idx)?.acl_add(entry)) } fn remove_all_acl(matter: &Matter<'_>, fab_idx: NonZeroU8) { - matter.with_state(|state| { - state.fabrics.acl_remove_all(fab_idx).unwrap(); - }) + matter.with_state(|state| state.fabrics.fabric_mut(fab_idx).unwrap().acl_remove_all()) } #[test] diff --git a/rs-matter/src/dm.rs b/rs-matter/src/dm.rs index f5e0c7ea6..39186b26d 100644 --- a/rs-matter/src/dm.rs +++ b/rs-matter/src/dm.rs @@ -31,6 +31,7 @@ use crate::im::{ StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq, WriteRespTag, PROTO_ID_INTERACTION_MODEL, }; +use crate::persist::KvBlobStoreAccess; use crate::respond::ExchangeHandler; use crate::tlv::{get_root_node_struct, FromTLV, Nullable, TLVElement, TLVTag, TLVWrite}; use 
crate::transport::exchange::{Exchange, MAX_EXCHANGE_RX_BUF_SIZE, MAX_EXCHANGE_TX_BUF_SIZE}; @@ -73,24 +74,26 @@ struct SubscriptionBuffer { /// An `ExchangeHandler` implementation capable of handling responder exchanges for the Interaction Model protocol. /// The implementation needs a `DataModelHandler` instance to interact with the underlying clusters of the data model. -pub struct DataModel<'a, const NS: usize, const NE: usize, C, B, T> +pub struct DataModel<'a, const NS: usize, const NE: usize, C, B, T, S> where B: BufferAccess, { matter: &'a Matter<'a>, crypto: C, buffers: &'a B, + kv: S, subscriptions: &'a Subscriptions, subscriptions_buffers: Mutex>, NS>>>, events: Option<&'a Events>, handler: T, } -impl<'a, const NS: usize, const NE: usize, C, B, T> DataModel<'a, NS, NE, C, B, T> +impl<'a, const NS: usize, const NE: usize, C, B, T, S> DataModel<'a, NS, NE, C, B, T, S> where C: Crypto, B: BufferAccess, T: DataModelHandler, + S: KvBlobStoreAccess, { /// Create the data model. /// @@ -99,9 +102,11 @@ where /// - `buffers` - a reference to an implementation of `BufferAccess` which is used for allocating RX and TX buffers on the fly, when necessary /// - `subscriptions` - a reference to a `Subscriptions` struct which is used for managing subscriptions. `N` designates the maximum /// number of subscriptions that can be managed by this handler. + /// - `events` - an optional reference to an `Events` struct which is used for managing events in the data model. `N` designates the maximum number of events that can be buffered in the event management system. /// - `handler` - an instance of type `T` which implements the `DataModelHandler` trait. This instance is used for interacting with the underlying /// clusters of the data model. 
Note that the expectations is for the user to provide a handler that handles the Matter system clusters /// as well (Endpoint 0), possibly by decorating her own clusters with the `rs_matter::dm::root_endpoint::with_` methods + /// - `kv` - an instance of type `S` which implements the `KvBlobStoreAccess` trait. This instance is used for interacting with the key-value blob store. #[inline(always)] pub const fn new( matter: &'a Matter<'a>, @@ -110,6 +115,7 @@ where subscriptions: &'a Subscriptions, events: Option<&'a Events>, handler: T, + kv: S, ) -> Self { Self { matter, @@ -119,6 +125,7 @@ where subscriptions_buffers: Mutex::new(RefCell::new(Vec::new())), events, handler, + kv, } } @@ -131,6 +138,11 @@ where &self.crypto } + /// Return a reference to the `KvBlobStoreAccess` instance used by this data model for interacting with the key-value blob store. + pub const fn kv(&self) -> &S { + &self.kv + } + /// Return a reference to the `ChangeNotify` instance used by this data model for tracking changes in the data model /// and notifying the subscription processing task about them. 
pub const fn change_notify(&self) -> &dyn ChangeNotify { @@ -144,6 +156,7 @@ where &self.crypto, &self.handler, self.buffers, + &self.kv, self.change_notify(), ); @@ -209,7 +222,6 @@ where if !is_groupcast { exchange.acknowledge().await?; } - exchange.matter().notify_persist(); Ok(()) } @@ -234,7 +246,13 @@ where &req, &node, None, - HandlerInvoker::new(exchange, &self.crypto, &self.handler, &self.buffers), + HandlerInvoker::new( + exchange, + &self.crypto, + &self.handler, + &self.buffers, + &self.kv, + ), EventReader::new(0), self.events, ); @@ -279,7 +297,13 @@ where let mut resp = WriteResponder::new( &req, &node, - HandlerInvoker::new(exchange, &self.crypto, &self.handler, &self.buffers), + HandlerInvoker::new( + exchange, + &self.crypto, + &self.handler, + &self.buffers, + &self.kv, + ), ); resp.respond(self.change_notify(), &mut wb, is_groupcast) @@ -327,7 +351,13 @@ where let mut resp = InvokeResponder::new( &req, &node, - HandlerInvoker::new(exchange, &self.crypto, &self.handler, &self.buffers), + HandlerInvoker::new( + exchange, + &self.crypto, + &self.handler, + &self.buffers, + &self.kv, + ), ); resp.respond(self.change_notify(), &mut wb, is_groupcast) @@ -711,7 +741,13 @@ where &req, &node, Some(id), - HandlerInvoker::new(exchange, &self.crypto, &self.handler, &self.buffers), + HandlerInvoker::new( + exchange, + &self.crypto, + &self.handler, + &self.buffers, + &self.kv, + ), EventReader::new(min_event_number), self.events, ); @@ -834,11 +870,13 @@ where } } -impl ExchangeHandler for DataModel<'_, NS, NE, C, B, T> +impl ExchangeHandler + for DataModel<'_, NS, NE, C, B, T, S> where C: Crypto, - T: DataModelHandler, B: BufferAccess, + T: DataModelHandler, + S: KvBlobStoreAccess, { fn handle(&self, exchange: &mut Exchange<'_>) -> impl Future> { DataModel::handle(self, exchange) @@ -853,20 +891,21 @@ where /// The responder handles chunking as needed. I.e. 
if reported data is too large to fit into a single /// Matter message, it will send the data in multiple chunks (i.e. with multiple Matter messages), waiting for /// a `Success` response from the peer after each chunk, and then continuing to send the next chunk until all data is sent. -struct ReportDataResponder<'a, 'b, 'c, C, D, B, const NE: usize> { +struct ReportDataResponder<'a, 'b, 'c, const NE: usize, C, D, B, S> { req: &'a ReportDataReq<'a>, node: &'a Node<'a>, subscription_id: Option, - invoker: HandlerInvoker<'b, 'c, C, D, B>, + invoker: HandlerInvoker<'b, 'c, C, D, B, S>, event_reader: EventReader, events: Option<&'a Events>, } -impl<'a, 'b, 'c, C, D, B, const NE: usize> ReportDataResponder<'a, 'b, 'c, C, D, B, NE> +impl<'a, 'b, 'c, const NE: usize, C, D, B, S> ReportDataResponder<'a, 'b, 'c, NE, C, D, B, S> where C: Crypto, D: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { // This is the amount of space we reserve for the structure/array closing TLVs // to be attached towards the end of long reads @@ -877,7 +916,7 @@ where req: &'a ReportDataReq<'a>, node: &'a Node<'a>, subscription_id: Option, - invoker: HandlerInvoker<'b, 'c, C, D, B>, + invoker: HandlerInvoker<'b, 'c, C, D, B, S>, event_reader: EventReader, events: Option<&'a Events>, ) -> Self { @@ -1220,23 +1259,24 @@ enum ReportDataChunkState { /// the other peers is sending, but processing all of those chunks is not done here, /// but is rather - a responsibility of the caller who should call in a loop `WriteResponder::respond` /// for all the chunks of the write request, until the `WriteReq::more_chunks()` returns `false`. 
-struct WriteResponder<'a, 'b, 'c, C, D, B> { +struct WriteResponder<'a, 'b, 'c, C, D, B, S> { req: &'a WriteReq<'a>, node: &'a Node<'a>, - invoker: HandlerInvoker<'b, 'c, C, D, B>, + invoker: HandlerInvoker<'b, 'c, C, D, B, S>, } -impl<'a, 'b, 'c, C, D, B> WriteResponder<'a, 'b, 'c, C, D, B> +impl<'a, 'b, 'c, C, D, B, S> WriteResponder<'a, 'b, 'c, C, D, B, S> where C: Crypto, D: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { /// Create a new `WriteResponder`. const fn new( req: &'a WriteReq<'a>, node: &'a Node<'a>, - invoker: HandlerInvoker<'b, 'c, C, D, B>, + invoker: HandlerInvoker<'b, 'c, C, D, B, S>, ) -> Self { Self { req, node, invoker } } @@ -1297,23 +1337,24 @@ where /// The simplest strategy for chunking would be to simply - and unconditionally - send each individual /// command response in a separate Matter message, i.e. if the invoke request contains 3 commands, /// the responder will send 3 Matter messages, each containing a single command response. -struct InvokeResponder<'a, 'b, 'c, C, D, B> { +struct InvokeResponder<'a, 'b, 'c, C, D, B, S> { req: &'a InvReq<'a>, node: &'a Node<'a>, - invoker: HandlerInvoker<'b, 'c, C, D, B>, + invoker: HandlerInvoker<'b, 'c, C, D, B, S>, } -impl<'a, 'b, 'c, C, D, B> InvokeResponder<'a, 'b, 'c, C, D, B> +impl<'a, 'b, 'c, C, D, B, S> InvokeResponder<'a, 'b, 'c, C, D, B, S> where C: Crypto, D: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { /// Create a new `InvokeResponder`. const fn new( req: &'a InvReq<'a>, node: &'a Node<'a>, - invoker: HandlerInvoker<'b, 'c, C, D, B>, + invoker: HandlerInvoker<'b, 'c, C, D, B, S>, ) -> Self { Self { req, node, invoker } } diff --git a/rs-matter/src/dm/clusters/acl.rs b/rs-matter/src/dm/clusters/acl.rs index 4cbea0d84..52b4992db 100644 --- a/rs-matter/src/dm/clusters/acl.rs +++ b/rs-matter/src/dm/clusters/acl.rs @@ -17,7 +17,6 @@ //! This module contains the implementation of the Access Control cluster and its handler. 
-use crate::utils::init::stack_try_pin_init; use core::num::NonZeroU8; use crate::acl::{self, AclEntry, MAX_ACL_ENTRIES_PER_FABRIC}; @@ -26,8 +25,9 @@ use crate::dm::{ ReadContext, WriteContext, }; use crate::error::{Error, ErrorCode}; -use crate::fabric::Fabrics; +use crate::fabric::{Fabric, FabricPersist, Fabrics}; use crate::tlv::{TLVArray, TLVBuilderParent}; +use crate::utils::init::stack_try_pin_init; use crate::with; pub use crate::dm::clusters::decl::access_control::*; @@ -90,8 +90,7 @@ impl AclHandler { /// Set the ACL entries in the fabrics fn set_acl( &self, - fabrics: &mut Fabrics, - fab_idx: NonZeroU8, + fabric: &mut Fabric, value: ArrayAttributeWrite< TLVArray<'_, AccessControlEntryStruct<'_>>, AccessControlEntryStruct<'_>, @@ -108,30 +107,27 @@ impl AclHandler { } let entry = entry?; // Init a dummy to propagate failures for bad inputs - stack_try_pin_init!(let _processed =? AclEntry::init_with(fab_idx, &entry)); + stack_try_pin_init!(let _processed =? AclEntry::init_with(fabric.fab_idx(), &entry)); } // Now add everything once we know all are valid - fabrics.acl_remove_all(fab_idx)?; + fabric.acl_remove_all(); for entry in &list { // unwrap! 
calls below can't fail because we already checked that the entry is well-formed // and the length of the list is within the limit let entry = unwrap!(entry); - unwrap!(fabrics.acl_add_init(fab_idx, AclEntry::init_with(fab_idx, &entry))); + unwrap!(fabric.acl_add_init(AclEntry::init_with(fabric.fab_idx(), &entry))); } } ArrayAttributeWrite::Add(entry) => { - fabrics.acl_add_init(fab_idx, AclEntry::init_with(fab_idx, &entry))?; + fabric.acl_add_init(AclEntry::init_with(fabric.fab_idx(), &entry))?; } ArrayAttributeWrite::Update(index, entry) => { - fabrics.acl_update_init( - fab_idx, - index as _, - AclEntry::init_with(fab_idx, &entry), - )?; + fabric + .acl_update_init(index as _, AclEntry::init_with(fabric.fab_idx(), &entry))?; } ArrayAttributeWrite::Remove(index) => { - fabrics.acl_remove(fab_idx, index as _)?; + fabric.acl_remove(index as _)?; } } @@ -182,10 +178,17 @@ impl ClusterHandler for AclHandler { AccessControlEntryStruct<'_>, >, ) -> Result<(), Error> { + let mut persist = FabricPersist::new(ctx.kv()); + ctx.exchange().with_state(|state| { let fab_idx = NonZeroU8::new(ctx.attr().fab_idx).ok_or(ErrorCode::Invalid)?; - self.set_acl(&mut state.fabrics, fab_idx, value) - }) + let fabric = state.fabrics.fabric_mut(fab_idx)?; + self.set_acl(fabric, value)?; + + persist.store(fabric) + })?; + + persist.run() } fn handle_review_fabric_restrictions( @@ -270,7 +273,11 @@ mod tests { AclEntry::new(Some(FAB_2), Privilege::ADMIN, AuthMode::Case), ]; for i in &verifier { - fabrics.acl_add(i.fab_idx.unwrap(), i.clone()).unwrap(); + fabrics + .fabric_mut(i.fab_idx.unwrap()) + .unwrap() + .acl_add(i.clone()) + .unwrap(); } let acl = AclHandler::new(Dataver::new(0)); @@ -318,7 +325,11 @@ mod tests { AclEntry::new(Some(FAB_2), Privilege::ADMIN, AuthMode::Case), ]; for i in &input { - fabrics.acl_add(i.fab_idx.unwrap(), i.clone()).unwrap(); + fabrics + .fabric_mut(i.fab_idx.unwrap()) + .unwrap() + .acl_add(i.clone()) + .unwrap(); } let acl = AclHandler::new(Dataver::new(0)); 
@@ -357,7 +368,11 @@ mod tests { AclEntry::new(Some(FAB_2), Privilege::ADMIN, AuthMode::Case), ]; for i in input { - fabrics.acl_add(i.fab_idx.unwrap(), i).unwrap(); + fabrics + .fabric_mut(i.fab_idx.unwrap()) + .unwrap() + .acl_add(i) + .unwrap(); } let acl = AclHandler::new(Dataver::new(0)); @@ -473,8 +488,7 @@ mod tests { fn acl_add(acl: &AclHandler, fabrics: &mut Fabrics, data: &TLVElement<'_>, fab_idx: NonZeroU8) { unwrap!(acl.set_acl( - fabrics, - fab_idx, + fabrics.fabric_mut(fab_idx).unwrap(), ArrayAttributeWrite::Add(AccessControlEntryStruct::new(data.clone())), )); } @@ -487,13 +501,15 @@ mod tests { fab_idx: NonZeroU8, ) { unwrap!(acl.set_acl( - fabrics, - fab_idx, + fabrics.fabric_mut(fab_idx).unwrap(), ArrayAttributeWrite::Update(index, AccessControlEntryStruct::new(data.clone())), )); } fn acl_remove(acl: &AclHandler, fabrics: &mut Fabrics, index: u16, fab_idx: NonZeroU8) { - unwrap!(acl.set_acl(fabrics, fab_idx, ArrayAttributeWrite::Remove(index))); + unwrap!(acl.set_acl( + fabrics.fabric_mut(fab_idx).unwrap(), + ArrayAttributeWrite::Remove(index) + )); } } diff --git a/rs-matter/src/dm/clusters/adm_comm.rs b/rs-matter/src/dm/clusters/adm_comm.rs index bc2dd3b0e..862ba517e 100644 --- a/rs-matter/src/dm/clusters/adm_comm.rs +++ b/rs-matter/src/dm/clusters/adm_comm.rs @@ -17,8 +17,12 @@ //! This module contains the implementation of the Administrative Commissioning cluster and its handler. 
+use rand_core::RngCore; + +use crate::crypto::Crypto; use crate::dm::{Cluster, Dataver, InvokeContext, ReadContext}; use crate::error::Error; +use crate::sc::pase::spake2p::SPAKE2P_VERIFIER_SALT_ZEROED; use crate::sc::pase::{CommWindowOpener, CommWindowType}; use crate::tlv::Nullable; use crate::MatterState; @@ -74,8 +78,12 @@ impl ClusterHandler for AdminCommHandler { } fn window_status(&self, ctx: impl ReadContext) -> Result { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { - let comm_window = state.pase.comm_window(&ctx)?; + let comm_window = state.pase.comm_window(notify_mdns, notify_change)?; let window_type = comm_window.map(|comm_window| comm_window.comm_window_type()); @@ -88,8 +96,12 @@ impl ClusterHandler for AdminCommHandler { } fn admin_fabric_index(&self, ctx: impl ReadContext) -> Result, Error> { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { - let comm_window = state.pase.comm_window(&ctx)?; + let comm_window = state.pase.comm_window(notify_mdns, notify_change)?; if let Some(opener) = comm_window.and_then(|comm_window| comm_window.opener()) { if state.fabrics.get(opener.fab_idx).is_some() { @@ -104,8 +116,12 @@ impl ClusterHandler for AdminCommHandler { } fn admin_vendor_id(&self, ctx: impl ReadContext) -> Result, Error> { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { - let comm_window = state.pase.comm_window(&ctx)?; + let comm_window = state.pase.comm_window(notify_mdns, notify_change)?; Ok(Nullable::new( comm_window @@ -120,17 +136,25 @@ impl ClusterHandler 
for AdminCommHandler { ctx: impl InvokeContext, request: OpenCommissioningWindowRequest<'_>, ) -> Result<(), Error> { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { let opener = Self::current_window_opener(state, &ctx.exchange().id()); + let mdns_id = ctx.crypto().rand()?.next_u64(); + state.pase.open_comm_window( - &ctx, + mdns_id, request.pake_passcode_verifier()?.0.try_into()?, request.salt()?.0.try_into()?, request.iterations()?, request.discriminator()?, request.commissioning_timeout()?, opener, + notify_mdns, + notify_change, ) }) } @@ -140,23 +164,42 @@ impl ClusterHandler for AdminCommHandler { ctx: impl InvokeContext, request: OpenBasicCommissioningWindowRequest<'_>, ) -> Result<(), Error> { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { let opener = Self::current_window_opener(state, &ctx.exchange().id()); let dev_comm = ctx.exchange().matter().dev_comm(); + let crypto = ctx.crypto(); + let mut rand = crypto.rand()?; + + let mdns_id = rand.next_u64(); + + let mut salt = SPAKE2P_VERIFIER_SALT_ZEROED; + rand.fill_bytes(salt.access_mut()); + state.pase.open_basic_comm_window( - &ctx, + mdns_id, + salt.reference(), dev_comm.password.reference(), dev_comm.discriminator, request.commissioning_timeout()?, opener, + notify_mdns, + notify_change, ) }) } fn handle_revoke_commissioning(&self, ctx: impl InvokeContext) -> Result<(), Error> { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange() - .with_state(|state| state.pase.close_comm_window(&ctx))?; + .with_state(|state| 
state.pase.close_comm_window(notify_mdns, notify_change))?; // TODO: Send status code if no commissioning window is open? diff --git a/rs-matter/src/dm/clusters/basic_info.rs b/rs-matter/src/dm/clusters/basic_info.rs index c1d9c52d1..64a2dceef 100644 --- a/rs-matter/src/dm/clusters/basic_info.rs +++ b/rs-matter/src/dm/clusters/basic_info.rs @@ -23,12 +23,12 @@ use crate::dm::subscriptions::DEFAULT_MAX_SUBSCRIPTIONS; use crate::dm::{Cluster, Dataver, InvokeContext, ReadContext, WriteContext}; use crate::error::{Error, ErrorCode}; use crate::fabric::MAX_FABRICS; -use crate::tlv::{FromTLV, Nullable, TLVBuilderParent, TLVElement, TLVTag, ToTLV, Utf8StrBuilder}; +use crate::persist::{KvBlobStore, Persist, BASIC_INFO_KEY}; +use crate::tlv::{FromTLV, Nullable, TLVBuilderParent, TLVElement, ToTLV, Utf8StrBuilder}; use crate::transport::exchange::Exchange; use crate::transport::session::MAX_SESSIONS; use crate::utils::bitflags::bitflags; use crate::utils::init::{init, Init}; -use crate::utils::storage::WriteBuf; use crate::{except, with}; pub use crate::dm::clusters::decl::basic_information::*; @@ -305,7 +305,6 @@ pub struct BasicInfoSettings { pub location: Option>, // Max location as per the spec pub location_type: RegulatoryLocationTypeEnum, pub local_config_disabled: bool, - pub changed: bool, } impl BasicInfoSettings { @@ -316,7 +315,6 @@ impl BasicInfoSettings { location: None, location_type: RegulatoryLocationTypeEnum::IndoorOutdoor, local_config_disabled: false, - changed: false, } } @@ -327,7 +325,6 @@ impl BasicInfoSettings { location: None, location_type: RegulatoryLocationTypeEnum::IndoorOutdoor, local_config_disabled: false, - changed: false, }) } @@ -335,43 +332,63 @@ impl BasicInfoSettings { /// /// # Arguments /// - `flag_changed`: whether to mark the basic info settings as changed - pub fn reset(&mut self, flag_changed: bool) { + pub fn reset(&mut self) { self.node_label.clear(); self.location = None; self.local_config_disabled = false; - self.changed = 
flag_changed; } - /// Load the basic info settings from the provided TLV data - pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { - *self = FromTLV::from_tlv(&TLVElement::new(data))?; - - self.changed = false; - - Ok(()) + pub fn set_location(&mut self, location: &str) { + if location == "XX" { + self.location = None; + } else { + self.location = Some(unwrap!(heapless::String::<2>::from_str(location))); + } } - /// Store the basic info settings into the provided buffer as TLV data - pub fn store(&mut self, buf: &mut [u8]) -> Result { - let mut wb = WriteBuf::new(buf); - - self.to_tlv(&TLVTag::Anonymous, &mut wb) - .map_err(|_| ErrorCode::NoSpace)?; + /// Remove all basic info settings from the provided BLOB store as well as from memory + /// + /// # Arguments + /// - `store`: the BLOB store to remove the settings from + /// - `buf`: a temporary buffer to use for removing the settings + pub async fn reset_persist( + &mut self, + mut store: S, + buf: &mut [u8], + ) -> Result<(), Error> { + self.reset(); - self.changed = false; + store.remove(BASIC_INFO_KEY, buf)?; - let len = wb.get_tail(); + info!("Removed basic info settings from storage"); - Ok(len) + Ok(()) } - pub fn set_location(&mut self, location: &str) { - if location == "XX" { - self.location = None; - } else { - self.location = Some(unwrap!(heapless::String::<2>::from_str(location))); + /// Load all basic info settings from the provided BLOB store + /// + /// # Arguments + /// - `store`: the BLOB store to load the fabrics from + /// - `buf`: a temporary buffer to use for loading the fabrics + pub async fn load_persist( + &mut self, + mut store: S, + buf: &mut [u8], + ) -> Result<(), Error> { + self.reset(); + + if let Some(data) = store.load(BASIC_INFO_KEY, buf)? 
{ + let info = Self::from_tlv(&TLVElement::new(data))?; + + self.node_label = info.node_label; + self.location = info.location; + self.location_type = info.location_type; + self.local_config_disabled = info.local_config_disabled; + + info!("Loaded basic info settings from storage"); } - self.changed = true; + + Ok(()) } } @@ -489,18 +506,19 @@ impl ClusterHandler for BasicInfoHandler { return Err(ErrorCode::ConstraintError.into()); } + let mut persist = Persist::new(ctx.kv()); + Self::with_settings(ctx.exchange(), |settings| { settings.node_label.clear(); settings .node_label .push_str(label) .map_err(|_| ErrorCode::ConstraintError)?; - settings.changed = true; - ctx.exchange().matter().notify_persist(); + persist.store_tlv(BASIC_INFO_KEY, &*settings) + })?; - Ok(()) - }) + persist.run() } fn location( @@ -518,13 +536,15 @@ impl ClusterHandler for BasicInfoHandler { return Err(ErrorCode::ConstraintError.into()); } + let mut persist = Persist::new(ctx.kv()); + Self::with_settings(ctx.exchange(), |settings| { settings.set_location(location); - ctx.exchange().matter().notify_persist(); + persist.store_tlv(BASIC_INFO_KEY, &*settings) + })?; - Ok(()) - }) + persist.run() } fn capability_minima( @@ -604,14 +624,15 @@ impl ClusterHandler for BasicInfoHandler { } fn set_local_config_disabled(&self, ctx: impl WriteContext, value: bool) -> Result<(), Error> { + let mut persist = Persist::new(ctx.kv()); + Self::with_settings(ctx.exchange(), |settings| { settings.local_config_disabled = value; - settings.changed = true; - ctx.exchange().matter().notify_persist(); + persist.store_tlv(BASIC_INFO_KEY, &*settings) + })?; - Ok(()) - }) + persist.run() } fn unique_id( diff --git a/rs-matter/src/dm/clusters/gen_comm.rs b/rs-matter/src/dm/clusters/gen_comm.rs index 652434a67..1711489ad 100644 --- a/rs-matter/src/dm/clusters/gen_comm.rs +++ b/rs-matter/src/dm/clusters/gen_comm.rs @@ -21,6 +21,7 @@ use core::fmt::Debug; use crate::dm::{Cluster, Dataver, InvokeContext, ReadContext, 
WriteContext}; use crate::error::{Error, ErrorCode}; +use crate::persist::{Persist, BASIC_INFO_KEY}; use crate::tlv::TLVBuilderParent; use crate::utils::sync::DynBase; use crate::with; @@ -187,9 +188,13 @@ impl ClusterHandler for GenCommHandler<'_> { fn handle_arm_fail_safe( &self, ctx: impl InvokeContext, - request: ArmFailSafeRequest, + request: ArmFailSafeRequest<'_>, response: ArmFailSafeResponseBuilder
<P>
, ) -> Result { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { let sess = ctx.exchange().id().session(&mut state.sessions); @@ -198,7 +203,8 @@ impl ClusterHandler for GenCommHandler<'_> { request.breadcrumb()?, sess.get_session_mode(), &mut state.pase, - &ctx, + notify_mdns, + notify_change, ))?; response.error_code(status)?.debug_text("")?.end() @@ -208,7 +214,7 @@ impl ClusterHandler for GenCommHandler<'_> { fn handle_set_regulatory_config( &self, ctx: impl InvokeContext, - request: SetRegulatoryConfigRequest, + request: SetRegulatoryConfigRequest<'_>, response: SetRegulatoryConfigResponseBuilder
<P>
, ) -> Result { let country_code = request.country_code()?; @@ -216,19 +222,26 @@ impl ClusterHandler for GenCommHandler<'_> { return Err(ErrorCode::ConstraintError.into()); } + let location_type = request.new_regulatory_config()?; + let breadcrumb = request.breadcrumb()?; + + let mut persist = Persist::new(ctx.kv()); + ctx.exchange().with_state(|state| { state.basic_info_settings.set_location(country_code); - state.basic_info_settings.location_type = request.new_regulatory_config()?; + state.basic_info_settings.location_type = location_type; - ctx.exchange().matter().notify_persist(); + state.failsafe.set_breadcrumb(breadcrumb); - state.failsafe.set_breadcrumb(request.breadcrumb()?); + persist.store_tlv(BASIC_INFO_KEY, &state.basic_info_settings) + })?; - response - .error_code(CommissioningErrorEnum::OK)? - .debug_text("")? - .end() - }) + persist.run()?; + + response + .error_code(CommissioningErrorEnum::OK)? + .debug_text("")? + .end() } fn handle_commissioning_complete( @@ -236,6 +249,10 @@ impl ClusterHandler for GenCommHandler<'_> { ctx: impl InvokeContext, response: CommissioningCompleteResponseBuilder
<P>
, ) -> Result { + let notify_mdns = || ctx.exchange().matter().notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| ctx.notify_attribute_changed(endpt_id, clust_id, attr_id); + ctx.exchange().with_state(|state| { let sess = ctx.exchange().id().session(&mut state.sessions); @@ -254,7 +271,7 @@ impl ClusterHandler for GenCommHandler<'_> { if matches!(status, CommissioningErrorEnum::OK) { // As per section 5.5 of the Matter Core Spec V1.3 we have to terminate the PASE session // upon completion of commissioning - state.pase.close_comm_window(&ctx)?; + state.pase.close_comm_window(notify_mdns, notify_change)?; } response.error_code(status)?.debug_text("")?.end() @@ -264,7 +281,7 @@ impl ClusterHandler for GenCommHandler<'_> { fn handle_set_tc_acknowledgements( &self, _ctx: impl InvokeContext, - _request: SetTCAcknowledgementsRequest, + _request: SetTCAcknowledgementsRequest<'_>, response: SetTCAcknowledgementsResponseBuilder
<P>
, ) -> Result { // TODO diff --git a/rs-matter/src/dm/clusters/groups.rs b/rs-matter/src/dm/clusters/groups.rs index 224dd20a4..6be479c08 100644 --- a/rs-matter/src/dm/clusters/groups.rs +++ b/rs-matter/src/dm/clusters/groups.rs @@ -21,6 +21,7 @@ use core::num::NonZeroU8; use crate::dm::{Cluster, Dataver, InvokeContext, ReadContext}; use crate::error::{Error, ErrorCode}; +use crate::fabric::FabricPersist; use crate::im::IMStatusCode; use crate::tlv::{Nullable, TLVBuilderParent}; use crate::{with, MatterState}; @@ -59,7 +60,8 @@ impl GroupsHandler { let fabric = state.fabrics.get(fab_idx).ok_or(ErrorCode::NotFound)?; let result = fabric - .group_key_map_iter() + .groups() + .key_map_iter() .any(|entry| entry.group_id == group_id); Ok(result) @@ -104,35 +106,35 @@ impl ClusterHandler for GroupsHandler { .end(); } - ctx.exchange().with_state(|state| { + let mut persist = FabricPersist::new(ctx.kv()); + + let status = ctx.exchange().with_state(|state| { // Check if group security material is available if !Self::has_group_material(state, fab_idx, group_id)? { - return response - .status(IMStatusCode::UnsupportedAccess as u8)? - .group_id(group_id)? - .end(); + return Ok(IMStatusCode::UnsupportedAccess); } // Add or update group membership let endpoint_id = ctx.cmd().endpoint_id; - match state - .fabrics - .group_add(endpoint_id, group_id, group_name, fab_idx) - { + let fabric = state.fabrics.fabric_mut(fab_idx)?; + + match fabric.groups_mut().add(endpoint_id, group_id, group_name) { Ok(_) => { + persist.store(fabric)?; ctx.exchange().matter().notify_groups_changed(); - response - .status(IMStatusCode::Success as u8)? - .group_id(group_id)? - .end() + + Ok(IMStatusCode::Success) } - Err(e) if e.code() == ErrorCode::ResourceExhausted => response - .status(IMStatusCode::ResourceExhausted as u8)? - .group_id(group_id)? 
- .end(), - Err(e) => Err(e), + Err(e) if e.code() == ErrorCode::ResourceExhausted => { + Ok(IMStatusCode::ResourceExhausted) + } + Err(e) => Err(e)?, } - }) + })?; + + persist.run()?; + + response.status(status as u8)?.group_id(group_id)?.end() } fn handle_view_group( @@ -160,7 +162,7 @@ impl ClusterHandler for GroupsHandler { let fabric = state.fabrics.get(fab_idx).ok_or(ErrorCode::NotFound)?; let endpoint_id = ctx.cmd().endpoint_id; - if let Some(entry) = fabric.group_get(group_id) { + if let Some(entry) = fabric.groups().get(group_id) { if entry.endpoints.contains(&endpoint_id) { return response .status(IMStatusCode::Success as u8)? @@ -190,7 +192,7 @@ impl ClusterHandler for GroupsHandler { let request_group_list = request.group_list()?; ctx.exchange().with_state(|state| { - let fabric = state.fabrics.get(fab_idx).ok_or(ErrorCode::NotFound)?; + let fabric = state.fabrics.fabric(fab_idx)?; // Capacity is nullable - return null to indicate unknown capacity let capacity = Nullable::::none(); @@ -200,7 +202,7 @@ impl ClusterHandler for GroupsHandler { if request_group_list.iter().count() == 0 { // Return all groups this endpoint is a member of - for entry in fabric.group_iter() { + for entry in fabric.groups().iter() { if entry.endpoints.contains(&endpoint_id) { group_list = group_list.push(&entry.group_id)?; } @@ -208,7 +210,7 @@ impl ClusterHandler for GroupsHandler { } else { // Return intersection: only requested groups that this endpoint is a member of for gid in request_group_list.into_iter().flatten() { - if let Some(entry) = fabric.group_get(gid) { + if let Some(entry) = fabric.groups().get(gid) { if entry.endpoints.contains(&endpoint_id) { group_list = group_list.push(&gid)?; } @@ -228,48 +230,54 @@ impl ClusterHandler for GroupsHandler { ) -> Result { let fab_idx = NonZeroU8::new(ctx.exchange().accessor()?.fab_idx).ok_or(ErrorCode::Invalid)?; - let group_id = request.group_id()?; + let endpoint_id = ctx.cmd().endpoint_id; - // Step 1: Validate 
constraints - if group_id == 0 { - return response - .status(IMStatusCode::ConstraintError as u8)? - .group_id(group_id)? - .end(); - } + let mut persist = FabricPersist::new(ctx.kv()); - // Steps 2-3: Remove membership - let endpoint_id = ctx.cmd().endpoint_id; - let removed = ctx - .exchange() - .with_state(|state| state.fabrics.group_remove(endpoint_id, group_id, fab_idx))?; + let status = ctx.exchange().with_state(|state| { + // Step 1: Validate constraints + if group_id == 0 { + return Ok(IMStatusCode::ConstraintError); + } - if removed { - ctx.exchange().matter().notify_groups_changed(); - response - .status(IMStatusCode::Success as u8)? - .group_id(group_id)? - .end() - } else { - response - .status(IMStatusCode::NotFound as u8)? - .group_id(group_id)? - .end() - } + let fabric = state.fabrics.fabric_mut(fab_idx)?; + + // Steps 2-3: Remove membership + if fabric.groups_mut().remove(endpoint_id, Some(group_id)) { + persist.store(fabric)?; + ctx.exchange().matter().notify_groups_changed(); + + Ok(IMStatusCode::Success) + } else { + Ok(IMStatusCode::NotFound) + } + })?; + + persist.run()?; + + response.status(status as u8)?.group_id(group_id)?.end() } fn handle_remove_all_groups(&self, ctx: impl InvokeContext) -> Result<(), Error> { let fab_idx = NonZeroU8::new(ctx.exchange().accessor()?.fab_idx).ok_or(ErrorCode::Invalid)?; - let endpoint_id = ctx.cmd().endpoint_id; + + let mut persist = FabricPersist::new(ctx.kv()); + ctx.exchange().with_state(|state| { - state - .fabrics - .group_remove_all_for_endpoint(endpoint_id, fab_idx) + let fabric = state.fabrics.fabric_mut(fab_idx)?; + + fabric.groups_mut().remove(endpoint_id, None); + + persist.store(fabric)?; + ctx.exchange().matter().notify_groups_changed(); + + Ok(()) })?; - ctx.exchange().matter().notify_groups_changed(); + + persist.run()?; Ok(()) } diff --git a/rs-matter/src/dm/clusters/grp_key_mgmt.rs b/rs-matter/src/dm/clusters/grp_key_mgmt.rs index c58629756..e72766544 100644 --- 
a/rs-matter/src/dm/clusters/grp_key_mgmt.rs +++ b/rs-matter/src/dm/clusters/grp_key_mgmt.rs @@ -25,7 +25,9 @@ use crate::dm::{ WriteContext, }; use crate::error::{Error, ErrorCode}; -use crate::fabric::GroupKeyMapping; +use crate::fabric::{ + FabricPersist, GroupKeyMapping, MAX_GROUPS_PER_FABRIC, MAX_GROUP_KEYS_PER_FABRIC, +}; use crate::group_keys::{GroupEpochKeyEntry, GroupKeySet}; use crate::tlv::{Nullable, Octets, TLVArray, TLVBuilderParent}; use crate::with; @@ -76,7 +78,8 @@ impl ClusterHandler for GrpKeyMgmtHandler { .filter(|fabric| !attr.fab_filter || fabric.fab_idx().get() == attr.fab_idx) .flat_map(|fabric| { fabric - .group_key_map_iter() + .groups() + .key_map_iter() .map(move |entry| (fabric.fab_idx(), entry)) }); @@ -124,7 +127,8 @@ impl ClusterHandler for GrpKeyMgmtHandler { .filter(|fabric| !attr.fab_filter || fabric.fab_idx().get() == attr.fab_idx) .flat_map(|fabric| { fabric - .group_iter() + .groups() + .iter() .map(move |entry| (fabric.fab_idx(), entry)) }); @@ -163,15 +167,13 @@ impl ClusterHandler for GrpKeyMgmtHandler { }) } - fn max_groups_per_fabric(&self, ctx: impl ReadContext) -> Result { - ctx.exchange() - .with_state(|state| Ok(state.fabrics.max_groups_per_fabric())) + fn max_groups_per_fabric(&self, _ctx: impl ReadContext) -> Result { + Ok(MAX_GROUPS_PER_FABRIC as _) } - fn max_group_keys_per_fabric(&self, ctx: impl ReadContext) -> Result { + fn max_group_keys_per_fabric(&self, _ctx: impl ReadContext) -> Result { // +1 for IPK (key set 0) - ctx.exchange() - .with_state(|state| Ok(state.fabrics.max_group_keys_per_fabric() + 1)) + Ok(MAX_GROUP_KEYS_PER_FABRIC as u16 + 1) } fn set_group_key_map( @@ -181,14 +183,18 @@ impl ClusterHandler for GrpKeyMgmtHandler { ) -> Result<(), Error> { let fab_idx = NonZeroU8::new(ctx.attr().fab_idx).ok_or(ErrorCode::Invalid)?; + let mut persist = FabricPersist::new(ctx.kv()); + ctx.exchange().with_state(|state| { + let fabric = state.fabrics.fabric_mut(fab_idx)?; + match value { 
ArrayAttributeWrite::Replace(list) => { // First validate all entries let mut count: usize = 0; for entry in &list { count += 1; - if count > state.fabrics.max_groups_per_fabric().into() { + if count > MAX_GROUP_KEYS_PER_FABRIC { return Err(ErrorCode::Failure.into()); } let entry = entry?; @@ -207,7 +213,7 @@ impl ClusterHandler for GrpKeyMgmtHandler { }) }); - state.fabrics.group_key_map_replace(fab_idx, entries)?; + fabric.groups_mut().key_map_replace(entries)?; } ArrayAttributeWrite::Add(entry) => { // GroupKeySetID must not be 0 @@ -215,25 +221,25 @@ impl ClusterHandler for GrpKeyMgmtHandler { return Err(ErrorCode::ConstraintError.into()); } - state.fabrics.group_key_map_add( - fab_idx, - GroupKeyMapping { - group_id: entry.group_id().map_err(|_| ErrorCode::InvalidCommand)?, - group_key_set_id: entry - .group_key_set_id() - .map_err(|_| ErrorCode::InvalidCommand)?, - }, - )?; + fabric.groups_mut().key_map_add(GroupKeyMapping { + group_id: entry.group_id().map_err(|_| ErrorCode::InvalidCommand)?, + group_key_set_id: entry + .group_key_set_id() + .map_err(|_| ErrorCode::InvalidCommand)?, + })?; } _ => { return Err(ErrorCode::InvalidAction.into()); } } + persist.store(fabric)?; ctx.exchange().matter().notify_groups_changed(); Ok(()) - }) + })?; + + persist.run() } fn handle_key_set_write( @@ -379,11 +385,19 @@ impl ClusterHandler for GrpKeyMgmtHandler { } } - ctx.exchange() - .with_state(|state| state.fabrics.group_key_set_add(fab_idx, entry))?; - ctx.exchange().matter().notify_groups_changed(); + let mut persist = FabricPersist::new(ctx.kv()); + + ctx.exchange().with_state(|state| { + let fabric = state.fabrics.fabric_mut(fab_idx)?; + + fabric.groups_mut().key_set_add(entry)?; + persist.store(fabric)?; + ctx.exchange().matter().notify_groups_changed(); + + Ok(()) + })?; - Ok(()) + persist.run() } fn handle_key_set_read( @@ -400,7 +414,8 @@ impl ClusterHandler for GrpKeyMgmtHandler { ctx.exchange().with_state(|state| { let fabric = 
state.fabrics.get(fab_idx).ok_or(ErrorCode::NotFound)?; let entry = fabric - .group_key_set_get(group_key_set_id) + .groups() + .key_set_get(group_key_set_id) .ok_or(ErrorCode::NotFound)?; // Build response: epoch keys are always null, start times are preserved @@ -450,14 +465,19 @@ impl ClusterHandler for GrpKeyMgmtHandler { return Err(ErrorCode::InvalidCommand.into()); } + let mut persist = FabricPersist::new(ctx.kv()); + ctx.exchange().with_state(|state| { - state - .fabrics - .group_key_set_remove(fab_idx, group_key_set_id) + let fabric = state.fabrics.fabric_mut(fab_idx)?; + + fabric.groups_mut().key_set_remove(group_key_set_id)?; + persist.store(fabric)?; + ctx.exchange().matter().notify_groups_changed(); + + Ok(()) })?; - ctx.exchange().matter().notify_groups_changed(); - Ok(()) + persist.run() } fn handle_key_set_read_all_indices( @@ -469,13 +489,13 @@ impl ClusterHandler for GrpKeyMgmtHandler { NonZeroU8::new(ctx.exchange().accessor()?.fab_idx).ok_or(ErrorCode::Invalid)?; ctx.exchange().with_state(|state| { - let fabric = state.fabrics.get(fab_idx).ok_or(ErrorCode::NotFound)?; + let fabric = state.fabrics.fabric(fab_idx)?; // Always include IPK (0) plus all stored key set IDs let mut ids = response.group_key_set_i_ds()?; ids = ids.push(&0u16)?; - for entry in fabric.group_key_set_iter() { + for entry in fabric.groups().key_set_iter() { ids = ids.push(&entry.group_key_set_id)?; } diff --git a/rs-matter/src/dm/clusters/level_control.rs b/rs-matter/src/dm/clusters/level_control.rs index 768ea9421..2d34501b3 100644 --- a/rs-matter/src/dm/clusters/level_control.rs +++ b/rs-matter/src/dm/clusters/level_control.rs @@ -29,7 +29,7 @@ //! - Designed for extensibility and integration with other clusters (e.g., OnOff). 
use core::cell::Cell; -use core::future::{pending, Future}; +use core::future::{pending, ready, Future}; use core::ops::Mul; use core::pin::pin; @@ -45,6 +45,7 @@ use crate::dm::{ use crate::error::{Error, ErrorCode}; use crate::tlv::Nullable; use crate::utils::cell::RefCell; +use crate::utils::future::delayed_ready; use crate::utils::sync::blocking::Mutex; use crate::utils::sync::Signal; @@ -1275,276 +1276,338 @@ impl ClusterAsyncHandler for LevelControlH self.dataver.changed(); } - async fn current_level(&self, _ctx: impl ReadContext) -> Result, Error> { - match self.hooks.current_level() { + fn current_level( + &self, + _ctx: impl ReadContext, + ) -> impl Future, Error>> { + delayed_ready(|| match self.hooks.current_level() { Some(level) => Ok(Nullable::some(level)), None => Ok(Nullable::none()), - } + }) } - async fn on_level(&self, _ctx: impl ReadContext) -> Result, Error> { - Ok(self.with_state(|state| state.on_level.clone())) + fn on_level( + &self, + _ctx: impl ReadContext, + ) -> impl Future, Error>> { + delayed_ready(|| Ok(self.with_state(|state| state.on_level.clone()))) } - async fn set_on_level(&self, ctx: impl WriteContext, value: Nullable) -> Result<(), Error> { - if let Some(level) = value.clone().into_option() { - if level > H::MAX_LEVEL || level < H::MIN_LEVEL { - return Err(ErrorCode::ConstraintError.into()); + fn set_on_level( + &self, + ctx: impl WriteContext, + value: Nullable, + ) -> impl Future> { + delayed_ready(|| { + if let Some(level) = value.clone().into_option() { + if level > H::MAX_LEVEL || level < H::MIN_LEVEL { + return Err(ErrorCode::ConstraintError.into()); + } } - } - self.with_state_notify(ctx, |state| { - state.on_level = value; - }); + self.with_state_notify(ctx, |state| { + state.on_level = value; + }); - Ok(()) + Ok(()) + }) } - async fn options(&self, _ctx: impl ReadContext) -> Result { - Ok(self.with_state(|state| state.options)) + fn options( + &self, + _ctx: impl ReadContext, + ) -> impl Future> { + delayed_ready(|| 
Ok(self.with_state(|state| state.options))) } - async fn set_options(&self, ctx: impl WriteContext, value: OptionsBitmap) -> Result<(), Error> { - self.with_state_notify(ctx, |state| { - state.options = value; - }); + fn set_options( + &self, + ctx: impl WriteContext, + value: OptionsBitmap, + ) -> impl Future> { + delayed_ready(move || { + self.with_state_notify(ctx, |state| { + state.options = value; + }); - Ok(()) + Ok(()) + }) } - async fn remaining_time(&self, _ctx: impl ReadContext) -> Result { - Ok(self.with_state(|state| state.remaining_time)) + fn remaining_time(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(|| Ok(self.with_state(|state| state.remaining_time))) } - async fn max_level(&self, _ctx: impl ReadContext) -> Result { - Ok(H::MAX_LEVEL) + fn max_level(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(|| Ok(H::MAX_LEVEL)) } - async fn min_level(&self, _ctx: impl ReadContext) -> Result { - Ok(H::MIN_LEVEL) + fn min_level(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(|| Ok(H::MIN_LEVEL)) } - async fn on_off_transition_time(&self, _ctx: impl ReadContext) -> Result { - Ok(self.with_state(|state| state.on_off_transition_time)) + fn on_off_transition_time( + &self, + _ctx: impl ReadContext, + ) -> impl Future> { + delayed_ready(|| Ok(self.with_state(|state| state.on_off_transition_time))) } - async fn set_on_off_transition_time( + fn set_on_off_transition_time( &self, ctx: impl WriteContext, value: u16, - ) -> Result<(), Error> { - self.with_state_notify(ctx, |state| { - state.on_off_transition_time = value; - }); + ) -> impl Future> { + delayed_ready(move || { + self.with_state_notify(ctx, |state| { + state.on_off_transition_time = value; + }); - Ok(()) + Ok(()) + }) } - async fn on_transition_time(&self, _ctx: impl ReadContext) -> Result, Error> { - Ok(self.with_state(|state| state.on_transition_time.clone())) + fn on_transition_time( + &self, + _ctx: impl ReadContext, + ) -> impl Future, Error>> { + 
delayed_ready(|| Ok(self.with_state(|state| state.on_transition_time.clone()))) } - async fn set_on_transition_time( + fn set_on_transition_time( &self, ctx: impl WriteContext, value: Nullable, - ) -> Result<(), Error> { - self.with_state_notify(ctx, |state| { - state.on_transition_time = value; - }); + ) -> impl Future> { + delayed_ready(|| { + self.with_state_notify(ctx, |state| { + state.on_transition_time = value; + }); - Ok(()) + Ok(()) + }) } - async fn off_transition_time(&self, _ctx: impl ReadContext) -> Result, Error> { - Ok(self.with_state(|state| state.off_transition_time.clone())) + fn off_transition_time( + &self, + _ctx: impl ReadContext, + ) -> impl Future, Error>> { + delayed_ready(|| Ok(self.with_state(|state| state.off_transition_time.clone()))) } - async fn set_off_transition_time( + fn set_off_transition_time( &self, ctx: impl WriteContext, value: Nullable, - ) -> Result<(), Error> { - self.with_state_notify(ctx, |state| { - state.off_transition_time = value; - }); + ) -> impl Future> { + delayed_ready(|| { + self.with_state_notify(ctx, |state| { + state.off_transition_time = value; + }); - Ok(()) + Ok(()) + }) } - async fn default_move_rate(&self, _ctx: impl ReadContext) -> Result, Error> { - Ok(self.with_state(|state| state.default_move_rate.clone())) + fn default_move_rate( + &self, + _ctx: impl ReadContext, + ) -> impl Future, Error>> { + delayed_ready(|| Ok(self.with_state(|state| state.default_move_rate.clone()))) } - async fn set_default_move_rate( + fn set_default_move_rate( &self, ctx: impl WriteContext, value: Nullable, - ) -> Result<(), Error> { - // The spec is not explicit about what should be done if this happens. - // For now we error out if DefaultMoveRate is equal to 0 as this is invalid - // until spec defines a behaviour. 
- if Some(0) == value.clone().into_option() { - return Err(ErrorCode::InvalidData.into()); - } + ) -> impl Future> { + delayed_ready(move || { + // The spec is not explicit about what should be done if this happens. + // For now we error out if DefaultMoveRate is equal to 0 as this is invalid + // until spec defines a behaviour. + if Some(0) == value.clone().into_option() { + return Err(ErrorCode::InvalidData.into()); + } - self.with_state_notify(ctx, |state| { - state.default_move_rate = value; - }); + self.with_state_notify(ctx, |state| { + state.default_move_rate = value; + }); - Ok(()) + Ok(()) + }) } - async fn start_up_current_level(&self, _ctx: impl ReadContext) -> Result, Error> { - match self.hooks.start_up_current_level()? { + fn start_up_current_level( + &self, + _ctx: impl ReadContext, + ) -> impl Future, Error>> { + delayed_ready(|| match self.hooks.start_up_current_level()? { Some(val) => Ok(Nullable::some(val)), None => Ok(Nullable::none()), - } + }) } - async fn set_start_up_current_level( + fn set_start_up_current_level( &self, ctx: impl WriteContext, value: Nullable, - ) -> Result<(), Error> { - // According to the current spec, this attribute does not have any constraints at this stage. - // However, it's usage is bounded by min/max hence it makes sense to restrict the settable values to this range. - if let Some(level) = value.clone().into_option() { - if level > H::MAX_LEVEL || level < H::MIN_LEVEL { - return Err(ErrorCode::ConstraintError.into()); + ) -> impl Future> { + delayed_ready(move || { + // According to the current spec, this attribute does not have any constraints at this stage. + // However, it's usage is bounded by min/max hence it makes sense to restrict the settable values to this range. 
+ if let Some(level) = value.clone().into_option() { + if level > H::MAX_LEVEL || level < H::MIN_LEVEL { + return Err(ErrorCode::ConstraintError.into()); + } } - } - self.hooks.set_start_up_current_level(value.into_option())?; - self.dataver_changed(); - ctx.notify_changed(); - Ok(()) + self.hooks.set_start_up_current_level(value.into_option())?; + self.dataver_changed(); + ctx.notify_changed(); + Ok(()) + }) } - async fn handle_move_to_level( + fn handle_move_to_level( &self, _ctx: impl InvokeContext, request: MoveToLevelRequest<'_>, - ) -> Result<(), Error> { - self.move_to_level( - false, - request.level()?, - request.transition_time()?.into_option(), - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.move_to_level( + false, + request.level()?, + request.transition_time()?.into_option(), + request.options_mask()?, + request.options_override()?, + ) + }) } - async fn handle_move( + fn handle_move( &self, _ctx: impl InvokeContext, request: MoveRequest<'_>, - ) -> Result<(), Error> { - self.with_state(|state| { - self.move_command( - state, - false, - request.move_mode()?, - request.rate()?.into_option(), - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.with_state(|state| { + self.move_command( + state, + false, + request.move_mode()?, + request.rate()?.into_option(), + request.options_mask()?, + request.options_override()?, + ) + }) }) } - async fn handle_step( + fn handle_step( &self, _ctx: impl InvokeContext, request: StepRequest<'_>, - ) -> Result<(), Error> { - self.step( - false, - request.step_mode()?, - request.step_size()?, - request.transition_time()?.into_option(), - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.step( + false, + request.step_mode()?, + request.step_size()?, + request.transition_time()?.into_option(), + request.options_mask()?, + 
request.options_override()?, + ) + }) } - async fn handle_stop( + fn handle_stop( &self, ctx: impl InvokeContext, request: StopRequest<'_>, - ) -> Result<(), Error> { - self.stop( - &ctx, - false, - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.stop( + &ctx, + false, + request.options_mask()?, + request.options_override()?, + ) + }) } - async fn handle_move_to_level_with_on_off( + fn handle_move_to_level_with_on_off( &self, _ctx: impl InvokeContext, request: MoveToLevelWithOnOffRequest<'_>, - ) -> Result<(), Error> { - self.move_to_level( - true, - request.level()?, - request.transition_time()?.into_option(), - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.move_to_level( + true, + request.level()?, + request.transition_time()?.into_option(), + request.options_mask()?, + request.options_override()?, + ) + }) } - async fn handle_move_with_on_off( + fn handle_move_with_on_off( &self, _ctx: impl InvokeContext, request: MoveWithOnOffRequest<'_>, - ) -> Result<(), Error> { - self.with_state(|state| { - self.move_command( - state, - true, - request.move_mode()?, - request.rate()?.into_option(), - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.with_state(|state| { + self.move_command( + state, + true, + request.move_mode()?, + request.rate()?.into_option(), + request.options_mask()?, + request.options_override()?, + ) + }) }) } - async fn handle_step_with_on_off( + fn handle_step_with_on_off( &self, _ctx: impl InvokeContext, request: StepWithOnOffRequest<'_>, - ) -> Result<(), Error> { - self.step( - true, - request.step_mode()?, - request.step_size()?, - request.transition_time()?.into_option(), - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.step( + true, + request.step_mode()?, + 
request.step_size()?, + request.transition_time()?.into_option(), + request.options_mask()?, + request.options_override()?, + ) + }) } - async fn handle_stop_with_on_off( + fn handle_stop_with_on_off( &self, ctx: impl InvokeContext, request: StopWithOnOffRequest<'_>, - ) -> Result<(), Error> { - self.stop( - &ctx, - true, - request.options_mask()?, - request.options_override()?, - ) + ) -> impl Future> { + delayed_ready(move || { + self.stop( + &ctx, + true, + request.options_mask()?, + request.options_override()?, + ) + }) } - async fn handle_move_to_closest_frequency( + fn handle_move_to_closest_frequency( &self, _ctx: impl InvokeContext, _request: MoveToClosestFrequencyRequest<'_>, - ) -> Result<(), Error> { - Err(ErrorCode::InvalidCommand.into()) + ) -> impl Future> { + ready(Err(ErrorCode::InvalidCommand.into())) } } @@ -1592,8 +1655,8 @@ pub trait LevelControlHooks { /// /// # Panics /// The SDK will panic if this method returns. - async fn run(&self, _notify: F) { - pending::<()>().await + fn run(&self, _notify: F) -> impl Future { + pending::<()>() } } diff --git a/rs-matter/src/dm/clusters/net_comm.rs b/rs-matter/src/dm/clusters/net_comm.rs index 667c793dd..445898071 100644 --- a/rs-matter/src/dm/clusters/net_comm.rs +++ b/rs-matter/src/dm/clusters/net_comm.rs @@ -18,15 +18,22 @@ //! This module contains the implementation of the Network Commissioning cluster and its handler. 
use core::fmt::{self, Debug}; +use core::future::{ready, Future}; use crate::dm::networks::wireless::{Thread, ThreadTLV, MAX_WIRELESS_NETWORK_ID_LEN}; +use crate::dm::networks::NetChangeNotif; use crate::dm::{ArrayAttributeRead, Cluster, Dataver, InvokeContext, ReadContext, WriteContext}; use crate::error::{Error, ErrorCode}; +use crate::persist::{Persist, NETWORKS_KEY}; use crate::tlv::{ Nullable, NullableBuilder, Octets, OctetsBuilder, TLVBuilder, TLVBuilderParent, TLVWrite, ToTLVArrayBuilder, ToTLVBuilder, }; -use crate::utils::sync::DynBase; +use crate::utils::cell::RefCell; +use crate::utils::future::delayed_ready; +use crate::utils::init::{init, Init}; +use crate::utils::sync::blocking::Mutex; +use crate::utils::sync::{DynBase, Notification}; use crate::{clusters, with}; pub use crate::dm::clusters::decl::network_commissioning::*; @@ -408,7 +415,7 @@ impl NetworkCommissioningStatusEnum { } /// Trait for managing networks' credentials storage -pub trait Networks: DynBase { +pub trait Networks { /// Return the maximum number of networks supported by the implementation /// /// For `NetworkType::Ethernet` this method should always return 1 @@ -452,7 +459,7 @@ pub trait Networks: DynBase { fn enabled(&self) -> Result; /// Set the network interface enabled or disabled - fn set_enabled(&self, enabled: bool) -> Result<(), Error>; + fn set_enabled(&mut self, enabled: bool) -> Result<(), Error>; /// Add or update the credentials for the given network ID /// @@ -461,7 +468,7 @@ pub trait Networks: DynBase { /// The network ID is derived from the credentials. /// /// Return the index of the network ID if it was added or updated, or an error if the operation failed. - fn add_or_update(&self, creds: &WirelessCreds<'_>) -> Result; + fn add_or_update(&mut self, creds: &WirelessCreds<'_>) -> Result; /// Reorder the network with the given index /// @@ -470,26 +477,30 @@ pub trait Networks: DynBase { /// The index is the new index of the network ID. 
/// /// Return the index of the network ID if it was reordered, or an error if the operation failed. - fn reorder(&self, index: u8, network_id: &[u8]) -> Result; + fn reorder(&mut self, index: u8, network_id: &[u8]) -> Result; /// Remove the network with the given network ID /// /// For `NetworkType::Ethernet` this method should always fail with an error. /// /// Return the index of the network ID if it was removed, or an error if the operation failed. - fn remove(&self, network_id: &[u8]) -> Result; + fn remove(&mut self, network_id: &[u8]) -> Result; + + /// Persist the networks' credentials into the given buffer and return the number of bytes written + /// or `None` if the networks do not need persistence. + fn persist(&self, buf: &mut [u8]) -> Result, Error>; } -impl Networks for &T +impl Networks for &mut T where T: Networks, { fn max_networks(&self) -> Result { - (*self).max_networks() + (**self).max_networks() } fn networks(&self, f: &mut dyn FnMut(&NetworkInfo) -> Result<(), Error>) -> Result<(), Error> { - (*self).networks(f) + (**self).networks(f) } fn creds( @@ -497,7 +508,7 @@ where network_id: &[u8], f: &mut dyn FnMut(&WirelessCreds) -> Result<(), Error>, ) -> Result { - (*self).creds(network_id, f) + (**self).creds(network_id, f) } fn next_creds( @@ -505,28 +516,45 @@ where last_network_id: Option<&[u8]>, f: &mut dyn FnMut(&WirelessCreds) -> Result<(), Error>, ) -> Result { - (*self).next_creds(last_network_id, f) + (**self).next_creds(last_network_id, f) } fn enabled(&self) -> Result { - (*self).enabled() + (**self).enabled() } - fn set_enabled(&self, enabled: bool) -> Result<(), Error> { + fn set_enabled(&mut self, enabled: bool) -> Result<(), Error> { (*self).set_enabled(enabled) } - fn add_or_update(&self, creds: &WirelessCreds<'_>) -> Result { + fn add_or_update(&mut self, creds: &WirelessCreds<'_>) -> Result { (*self).add_or_update(creds) } - fn reorder(&self, index: u8, network_id: &[u8]) -> Result { + fn reorder(&mut self, index: u8, 
network_id: &[u8]) -> Result { (*self).reorder(index, network_id) } - fn remove(&self, network_id: &[u8]) -> Result { + fn remove(&mut self, network_id: &[u8]) -> Result { (*self).remove(network_id) } + + fn persist(&self, buf: &mut [u8]) -> Result, Error> { + (**self).persist(buf) + } +} + +pub trait NetworksAccess { + fn access R, R>(&self, f: F) -> R; +} + +impl NetworksAccess for &T +where + T: NetworksAccess, +{ + fn access R, R>(&self, f: F) -> R { + (*self).access(f) + } } /// Trait for managing network connectivity @@ -622,15 +650,15 @@ where (*self).thread_version() } - async fn scan(&self, network: Option<&[u8]>, f: F) -> Result<(), NetCtlError> + fn scan(&self, network: Option<&[u8]>, f: F) -> impl Future> where F: FnMut(&NetworkScanInfo) -> Result<(), Error>, { - (*self).scan(network, f).await + (*self).scan(network, f) } - async fn connect(&self, creds: &WirelessCreds<'_>) -> Result<(), NetCtlError> { - (*self).connect(creds).await + fn connect(&self, creds: &WirelessCreds<'_>) -> impl Future> { + (*self).connect(creds) } } @@ -674,17 +702,151 @@ where } } +/// A type providing shared access to a `Networks` implementation with change notification capabilities. +pub struct SharedNetworks { + state: Mutex>, + state_changed: Notification, +} + +impl SharedNetworks { + /// Create a new instance. + pub const fn new(networks: N) -> Self { + Self { + state: Mutex::new(RefCell::new(networks)), + state_changed: Notification::new(), + } + } + + /// Return an in-place initializer for the struct. + pub fn init(networks: impl Init) -> impl Init { + init!(Self { + state <- Mutex::init(RefCell::init(networks)), + state_changed: Notification::new(), + }) + } + + /// Get a mutable reference to the inner `Networks` implementation. + pub fn get_mut(&mut self) -> &mut RefCell { + self.state.get_mut() + } + + /// Wait for the state to change. 
+ pub fn wait_state_changed(&self) -> impl Future + '_ { + self.state_changed.wait() + } +} + +impl DynBase for SharedNetworks where N: Send {} + +impl NetworksAccess for SharedNetworks +where + N: Networks, +{ + fn access R, R>(&self, f: F) -> R { + self.state.lock(|state| { + let mut networks = state.borrow_mut(); + + let mut instance = SharedNetworksInstance { + networks: &mut *networks, + changed: &self.state_changed, + }; + + f(&mut instance) + }) + } +} + +impl NetChangeNotif for SharedNetworks { + fn wait_changed(&self) -> impl Future { + self.state_changed.wait() + } +} + +/// A wrapper around a `Networks` implementation that notifies on changes to the networks state. +pub struct SharedNetworksInstance<'a> { + networks: &'a mut dyn Networks, + changed: &'a Notification, +} + +impl Networks for SharedNetworksInstance<'_> { + fn max_networks(&self) -> Result { + self.networks.max_networks() + } + + fn networks(&self, f: &mut dyn FnMut(&NetworkInfo) -> Result<(), Error>) -> Result<(), Error> { + self.networks.networks(f) + } + + fn creds( + &self, + network_id: &[u8], + f: &mut dyn FnMut(&WirelessCreds) -> Result<(), Error>, + ) -> Result { + self.networks.creds(network_id, f) + } + + fn next_creds( + &self, + last_network_id: Option<&[u8]>, + f: &mut dyn FnMut(&WirelessCreds) -> Result<(), Error>, + ) -> Result { + self.networks.next_creds(last_network_id, f) + } + + fn enabled(&self) -> Result { + self.networks.enabled() + } + + fn set_enabled(&mut self, enabled: bool) -> Result<(), Error> { + self.networks.set_enabled(enabled)?; + + self.changed.notify(); + + Ok(()) + } + + fn add_or_update(&mut self, creds: &WirelessCreds<'_>) -> Result { + let index = self.networks.add_or_update(creds)?; + + self.changed.notify(); + + Ok(index) + } + + fn reorder(&mut self, index: u8, network_id: &[u8]) -> Result { + let index = self.networks.reorder(index, network_id)?; + + self.changed.notify(); + + Ok(index) + } + + fn remove(&mut self, network_id: &[u8]) -> Result { + 
let index = self.networks.remove(network_id)?; + + self.changed.notify(); + + Ok(index) + } + + fn persist(&self, buf: &mut [u8]) -> Result, Error> { + let len = self.networks.persist(buf)?; + + Ok(len) + } +} + /// The system implementation of a handler for the Network Commissioning Matter cluster. #[derive(Clone)] -pub struct NetCommHandler<'a, T> { +pub struct NetCommHandler { dataver: Dataver, - networks: &'a dyn Networks, + networks: N, net_ctl: T, } -impl<'a, T> NetCommHandler<'a, T> { +impl NetCommHandler { /// Create a new instance of `NetCommHandler` with the given `Dataver`, `Networks` and `NetCtl`. - pub const fn new(dataver: Dataver, networks: &'a dyn Networks, net_ctl: T) -> Self { + pub const fn new(dataver: Dataver, networks: N, net_ctl: T) -> Self { Self { dataver, networks, @@ -698,8 +860,9 @@ impl<'a, T> NetCommHandler<'a, T> { } } -impl ClusterAsyncHandler for NetCommHandler<'_, T> +impl ClusterAsyncHandler for NetCommHandler where + N: NetworksAccess, T: NetCtl + NetCtlStatus, { const CLUSTER: Cluster<'static> = NetworkType::Ethernet.cluster(); // TODO @@ -712,27 +875,33 @@ where self.dataver.changed(); } - async fn max_networks(&self, _ctx: impl ReadContext) -> Result { - self.networks.max_networks() + fn max_networks(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(move || self.networks.access(|networks| networks.max_networks())) } - async fn connect_max_time_seconds(&self, _ctx: impl ReadContext) -> Result { - Ok(self.net_ctl.connect_max_time_seconds()) + fn connect_max_time_seconds( + &self, + _ctx: impl ReadContext, + ) -> impl Future> { + delayed_ready(move || Ok(self.net_ctl.connect_max_time_seconds())) } - async fn scan_max_time_seconds(&self, _ctx: impl ReadContext) -> Result { - Ok(self.net_ctl.scan_max_time_seconds()) + fn scan_max_time_seconds( + &self, + _ctx: impl ReadContext, + ) -> impl Future> { + delayed_ready(move || Ok(self.net_ctl.scan_max_time_seconds())) } - async fn supported_wi_fi_bands( + fn 
supported_wi_fi_bands( &self, _ctx: impl ReadContext, builder: ArrayAttributeRead< ToTLVArrayBuilder, ToTLVBuilder, >, - ) -> Result { - match builder { + ) -> impl Future> { + delayed_ready(move || match builder { ArrayAttributeRead::ReadAll(builder) => builder.with(|builder| { let mut builder = Some(builder); @@ -766,100 +935,115 @@ where } } ArrayAttributeRead::ReadNone(builder) => builder.end(), - } + }) } - async fn supported_thread_features( + fn supported_thread_features( &self, _ctx: impl ReadContext, - ) -> Result { - Ok(self.net_ctl.supported_thread_features()) + ) -> impl Future> { + delayed_ready(move || Ok(self.net_ctl.supported_thread_features())) } - async fn thread_version(&self, _ctx: impl ReadContext) -> Result { - Ok(self.net_ctl.thread_version()) + fn thread_version(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(move || Ok(self.net_ctl.thread_version())) } - async fn networks( + fn networks( &self, _ctx: impl ReadContext, builder: ArrayAttributeRead, NetworkInfoStructBuilder

>, - ) -> Result { - match builder { - ArrayAttributeRead::ReadAll(builder) => builder.with(|builder| { - let mut builder = Some(builder); + ) -> impl Future> { + delayed_ready(move || { + self.networks.access(|networks| match builder { + ArrayAttributeRead::ReadAll(builder) => builder.with(|builder| { + let mut builder = Some(builder); - self.networks.networks(&mut |ni| { - builder = Some(ni.read_into(unwrap!(builder.take()).push()?)?); + networks.networks(&mut |ni| { + builder = Some(ni.read_into(unwrap!(builder.take()).push()?)?); - Ok(()) - })?; - - unwrap!(builder.take()).end() - }), - ArrayAttributeRead::ReadOne(index, builder) => { - let mut current = 0; - let mut builder = Some(builder); - let mut parent = None; - - self.networks.networks(&mut |ni| { - if current == index { - parent = Some(ni.read_into(unwrap!(builder.take()))?); - } + Ok(()) + })?; + + unwrap!(builder.take()).end() + }), + ArrayAttributeRead::ReadOne(index, builder) => { + let mut current = 0; + let mut builder = Some(builder); + let mut parent = None; + + networks.networks(&mut |ni| { + if current == index { + parent = Some(ni.read_into(unwrap!(builder.take()))?); + } - current += 1; + current += 1; - Ok(()) - })?; + Ok(()) + })?; - if let Some(parent) = parent { - Ok(parent) - } else { - Err(ErrorCode::ConstraintError.into()) + if let Some(parent) = parent { + Ok(parent) + } else { + Err(ErrorCode::ConstraintError.into()) + } } - } - ArrayAttributeRead::ReadNone(builder) => builder.end(), - } + ArrayAttributeRead::ReadNone(builder) => builder.end(), + }) + }) } - async fn interface_enabled(&self, _ctx: impl ReadContext) -> Result { - self.networks.enabled() + fn interface_enabled( + &self, + _ctx: impl ReadContext, + ) -> impl Future> { + delayed_ready(move || self.networks.access(|networks| networks.enabled())) } - async fn last_networking_status( + fn last_networking_status( &self, _ctx: impl ReadContext, - ) -> Result, Error> { - 
Ok(Nullable::new(self.net_ctl.last_networking_status()?)) + ) -> impl Future, Error>> { + delayed_ready(move || Ok(Nullable::new(self.net_ctl.last_networking_status()?))) } - async fn last_network_id( + fn last_network_id( &self, _ctx: impl ReadContext, builder: NullableBuilder>, - ) -> Result { - self.net_ctl.last_network_id(|network_id| { - if let Some(network_id) = network_id { - builder.non_null()?.set(Octets::new(network_id)) - } else { - builder.null() - } + ) -> impl Future> { + delayed_ready(move || { + self.net_ctl.last_network_id(|network_id| { + if let Some(network_id) = network_id { + builder.non_null()?.set(Octets::new(network_id)) + } else { + builder.null() + } + }) }) } - async fn last_connect_error_value( + fn last_connect_error_value( &self, _ctx: impl ReadContext, - ) -> Result, Error> { - Ok(Nullable::new(self.net_ctl.last_connect_error_value()?)) + ) -> impl Future, Error>> { + delayed_ready(move || Ok(Nullable::new(self.net_ctl.last_connect_error_value()?))) } async fn set_interface_enabled( &self, - _ctx: impl WriteContext, + ctx: impl WriteContext, value: bool, ) -> Result<(), Error> { - self.networks.set_enabled(value) + let mut persist = Persist::new(ctx.kv()); + + self.networks.access(|networks| { + networks.set_enabled(value)?; + + persist.store(NETWORKS_KEY, |buf| networks.persist(buf)) + })?; + + persist.run() } async fn handle_scan_networks( @@ -971,43 +1155,71 @@ where async fn handle_add_or_update_wi_fi_network( &self, - _ctx: impl InvokeContext, + ctx: impl InvokeContext, request: AddOrUpdateWiFiNetworkRequest<'_>, response: NetworkConfigResponseBuilder

, ) -> Result { - let (status, _, index) = NetworkCommissioningStatusEnum::map(self.networks.add_or_update( - &WirelessCreds::Wifi { - ssid: request.ssid()?.0, - pass: request.credentials()?.0, - }, - ))?; + let mut persist = Persist::new(ctx.kv()); + + let (status, _, index) = + NetworkCommissioningStatusEnum::map(self.networks.access(|networks| { + let index = networks.add_or_update(&WirelessCreds::Wifi { + ssid: request.ssid()?.0, + pass: request.credentials()?.0, + })?; + + persist.store(NETWORKS_KEY, |buf| networks.persist(buf))?; + + Ok(index) + }))?; + + persist.run()?; status.read_into(index, response) } async fn handle_add_or_update_thread_network( &self, - _ctx: impl InvokeContext, + ctx: impl InvokeContext, request: AddOrUpdateThreadNetworkRequest<'_>, response: NetworkConfigResponseBuilder

, ) -> Result { - let (status, _, index) = NetworkCommissioningStatusEnum::map(self.networks.add_or_update( - &WirelessCreds::Thread { - dataset_tlv: request.operational_dataset()?.0, - }, - ))?; + let mut persist = Persist::new(ctx.kv()); + + let (status, _, index) = + NetworkCommissioningStatusEnum::map(self.networks.access(|networks| { + let index = networks.add_or_update(&WirelessCreds::Thread { + dataset_tlv: request.operational_dataset()?.0, + })?; + + persist.store(NETWORKS_KEY, |buf| networks.persist(buf))?; + + Ok(index) + }))?; + + persist.run()?; status.read_into(index, response) } async fn handle_remove_network( &self, - _ctx: impl InvokeContext, + ctx: impl InvokeContext, request: RemoveNetworkRequest<'_>, response: NetworkConfigResponseBuilder

, ) -> Result { + let mut persist = Persist::new(ctx.kv()); + let (status, _, index) = - NetworkCommissioningStatusEnum::map(self.networks.remove(request.network_id()?.0))?; + NetworkCommissioningStatusEnum::map(self.networks.access(|networks| { + let index = networks.remove(request.network_id()?.0)?; + + persist.store(NETWORKS_KEY, |buf| networks.persist(buf))?; + + Ok(index) + }))?; + + persist.run()?; status.read_into(index, response) } @@ -1027,24 +1239,25 @@ where let dataset_buf = response.writer().available_space(); let mut dataset_len = 0; - let (mut status, mut err_code, _) = NetworkCommissioningStatusEnum::map( - self.networks.creds(request.network_id()?.0, &mut |creds| { - let WirelessCreds::Thread { dataset_tlv } = creds else { - error!("Thread creds expected"); - return Err(ErrorCode::InvalidAction.into()); - }; + let (mut status, mut err_code, _) = + NetworkCommissioningStatusEnum::map(self.networks.access(|networks| { + networks.creds(request.network_id()?.0, &mut |creds| { + let WirelessCreds::Thread { dataset_tlv } = creds else { + error!("Thread creds expected"); + return Err(ErrorCode::InvalidAction.into()); + }; - if dataset_tlv.len() > dataset_buf.len() { - error!("Dataset too large"); - return Err(ErrorCode::ConstraintError.into()); - } + if dataset_tlv.len() > dataset_buf.len() { + error!("Dataset too large"); + return Err(ErrorCode::ConstraintError.into()); + } - dataset_buf[..dataset_tlv.len()].copy_from_slice(dataset_tlv); - dataset_len = dataset_tlv.len(); + dataset_buf[..dataset_tlv.len()].copy_from_slice(dataset_tlv); + dataset_len = dataset_tlv.len(); - Ok(()) - }), - )?; + Ok(()) + }) + }))?; if matches!(status, NetworkCommissioningStatusEnum::Success) { (status, err_code, _) = NetworkCommissioningStatusEnum::map_ctl( @@ -1064,31 +1277,32 @@ where let mut ssid_len = 0; let mut pass_len = 0; - let (mut status, mut err_code, _) = NetworkCommissioningStatusEnum::map( - self.networks.creds(request.network_id()?.0, &mut |creds| { - let 
WirelessCreds::Wifi { ssid, pass } = creds else { - error!("Wifi creds expected"); - return Err(ErrorCode::InvalidAction.into()); - }; - - if ssid.len() > ssid_buf.len() { - error!("SSID too large"); - return Err(ErrorCode::ConstraintError.into()); - } - - if pass.len() > pass_buf.len() { - error!("Password too large"); - return Err(ErrorCode::ConstraintError.into()); - } - - ssid_buf[..ssid.len()].copy_from_slice(ssid); - ssid_len = ssid.len(); - pass_buf[..pass.len()].copy_from_slice(pass); - pass_len = pass.len(); - - Ok(()) - }), - )?; + let (mut status, mut err_code, _) = + NetworkCommissioningStatusEnum::map(self.networks.access(|networks| { + networks.creds(request.network_id()?.0, &mut |creds| { + let WirelessCreds::Wifi { ssid, pass } = creds else { + error!("Wifi creds expected"); + return Err(ErrorCode::InvalidAction.into()); + }; + + if ssid.len() > ssid_buf.len() { + error!("SSID too large"); + return Err(ErrorCode::ConstraintError.into()); + } + + if pass.len() > pass_buf.len() { + error!("Password too large"); + return Err(ErrorCode::ConstraintError.into()); + } + + ssid_buf[..ssid.len()].copy_from_slice(ssid); + ssid_len = ssid.len(); + pass_buf[..pass.len()].copy_from_slice(pass); + pass_len = pass.len(); + + Ok(()) + }) + }))?; if matches!(status, NetworkCommissioningStatusEnum::Success) { (status, err_code, _) = NetworkCommissioningStatusEnum::map_ctl( @@ -1117,29 +1331,38 @@ where async fn handle_reorder_network( &self, - _ctx: impl InvokeContext, + ctx: impl InvokeContext, request: ReorderNetworkRequest<'_>, response: NetworkConfigResponseBuilder

, ) -> Result { - let (status, _, index) = NetworkCommissioningStatusEnum::map( - self.networks - .reorder(request.network_index()? as _, request.network_id()?.0), - )?; + let mut persist = Persist::new(ctx.kv()); + + let (status, _, index) = + NetworkCommissioningStatusEnum::map(self.networks.access(|networks| { + let index = + networks.reorder(request.network_index()? as _, request.network_id()?.0)?; + + persist.store(NETWORKS_KEY, |buf| networks.persist(buf))?; + + Ok(index) + }))?; + + persist.run()?; status.read_into(index, response) } - async fn handle_query_identity( + fn handle_query_identity( &self, _ctx: impl InvokeContext, _request: QueryIdentityRequest<'_>, _response: QueryIdentityResponseBuilder

, - ) -> Result { - Err(ErrorCode::InvalidAction.into()) + ) -> impl Future> { + ready(Err(ErrorCode::InvalidAction.into())) } } -impl Debug for NetCommHandler<'_, ()> { +impl Debug for NetCommHandler { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("NetCommHandler") .field("dataver", &self.dataver.get()) @@ -1148,7 +1371,7 @@ impl Debug for NetCommHandler<'_, ()> { } #[cfg(feature = "defmt")] -impl defmt::Format for NetCommHandler<'_, ()> { +impl defmt::Format for NetCommHandler { fn format(&self, f: defmt::Formatter) { defmt::write!(f, "NetCommHandler {{ dataver: {} }}", self.dataver.get()); } diff --git a/rs-matter/src/dm/clusters/noc.rs b/rs-matter/src/dm/clusters/noc.rs index 1e4a46019..a06b26dd5 100644 --- a/rs-matter/src/dm/clusters/noc.rs +++ b/rs-matter/src/dm/clusters/noc.rs @@ -26,7 +26,7 @@ use crate::crypto::{CanonPkcSignature, Crypto, SigningSecretKey}; use crate::dm::clusters::dev_att::DeviceAttestation; use crate::dm::{ArrayAttributeRead, Cluster, Dataver, InvokeContext, ReadContext}; use crate::error::{Error, ErrorCode}; -use crate::fabric::{Fabric, MAX_FABRICS}; +use crate::fabric::{Fabric, FabricPersist, MAX_FABRICS}; use crate::tlv::{ Nullable, Octets, OctetsArrayBuilder, OctetsBuilder, TLVBuilder, TLVBuilderParent, TLVElement, TLVTag, TLVWrite, @@ -432,10 +432,12 @@ impl ClusterHandler for NocHandler { let buf = response.writer().available_space(); + let mut persist = FabricPersist::new(ctx.kv()); + let status = NodeOperationalCertStatusEnum::map(ctx.exchange().with_state(|state| { let sess = ctx.exchange().id().session(&mut state.sessions); - let fab_idx = state.failsafe.add_noc( + let fabric = state.failsafe.add_noc( ctx.crypto(), &mut state.fabrics, sess.get_session_mode(), @@ -448,6 +450,9 @@ impl ClusterHandler for NocHandler { &mut || ctx.exchange().matter().notify_mdns(), )?; + persist.store(fabric)?; + + let fab_idx = fabric.fab_idx(); let succeeded = Cell::new(false); let _fab_guard = 
scopeguard::guard(fab_idx, |fab_idx| { @@ -458,6 +463,8 @@ impl ClusterHandler for NocHandler { unwrap!(state .fabrics .remove(fab_idx, &mut || ctx.exchange().matter().notify_mdns())); + + persist.remove(fab_idx).ok(); // Best effort } }); @@ -466,12 +473,13 @@ impl ClusterHandler for NocHandler { } succeeded.set(true); - added_fab_idx = Some(fab_idx.get()); Ok(()) }))?; + persist.run()?; + response .status_code(status)? .fabric_index(added_fab_idx)? @@ -495,10 +503,12 @@ impl ClusterHandler for NocHandler { let buf = response.writer().available_space(); + let mut persist = FabricPersist::new(ctx.kv()); + let status = NodeOperationalCertStatusEnum::map(ctx.exchange().with_state(|state| { let sess = ctx.exchange().id().session(&mut state.sessions); - state.failsafe.update_noc( + let fabric = state.failsafe.update_noc( ctx.crypto(), &mut state.fabrics, sess.get_session_mode(), @@ -508,9 +518,13 @@ impl ClusterHandler for NocHandler { &mut || ctx.exchange().matter().notify_mdns(), )?; + persist.store(fabric)?; + Ok(()) }))?; + persist.run()?; + response .status_code(status)? .fabric_index(Some(ctx.cmd().fab_idx))? @@ -528,6 +542,8 @@ impl ClusterHandler for NocHandler { let mut updated_fab_idx = None; + let mut persist = FabricPersist::new(ctx.kv()); + let status = NodeOperationalCertStatusEnum::map(ctx.exchange().with_state(|state| { let sess = ctx.exchange().id().session(&mut state.sessions); @@ -535,9 +551,7 @@ impl ClusterHandler for NocHandler { return Err(ErrorCode::GennCommInvalidAuthentication.into()); }; - updated_fab_idx = Some(fab_idx.get()); - - state + let fabric = state .fabrics .update_label(*fab_idx, request.label()?) .map_err(|e| { @@ -546,9 +560,16 @@ impl ClusterHandler for NocHandler { } else { e } - }) + })?; + + persist.store(fabric)?; + updated_fab_idx = Some(fabric.fab_idx().get()); + + Ok(()) }))?; + persist.run()?; + response .status_code(status)? .fabric_index(updated_fab_idx)? 
@@ -566,10 +587,12 @@ impl ClusterHandler for NocHandler { let fab_idx = NonZeroU8::new(request.fabric_index()?).ok_or(ErrorCode::ConstraintError)?; - ctx.exchange().with_state(|state| { + let mut persist = FabricPersist::new(ctx.kv()); + + let status = ctx.exchange().with_state(|state| { let sess = ctx.exchange().id().session(&mut state.sessions); - let status = if state + if state .fabrics .remove(fab_idx, &mut || ctx.exchange().matter().notify_mdns()) .is_ok() @@ -584,31 +607,29 @@ impl ClusterHandler for NocHandler { // If `expire_sess_id` is Some, the session will be expired instead of removed. state.sessions.remove_for_fabric(fab_idx, expire_sess_id); - // Notify that the fabrics need to be persisted - // We need to explicitly do this because if the fabric being removed - // is the one on which the session is running, the session will be removed - // and the response will fail - ctx.exchange().matter().notify_persist(); - // Notify that a session was removed ctx.exchange().matter().session_removed.notify(); // Note that since we might have removed our own session, the exchange // will terminate with a "NoSession" error, but that's OK and handled properly + persist.remove(fab_idx)?; + info!("Removed operational fabric with local index {}", fab_idx); - NodeOperationalCertStatusEnum::OK + Ok(NodeOperationalCertStatusEnum::OK) } else { - NodeOperationalCertStatusEnum::InvalidFabricIndex - }; + Ok(NodeOperationalCertStatusEnum::InvalidFabricIndex) + } + })?; - response - .status_code(status)? - .fabric_index(Some(fab_idx.get()))? - .debug_text(None)? - .end() - }) + persist.run()?; + + response + .status_code(status)? + .fabric_index(Some(fab_idx.get()))? + .debug_text(None)? 
+ .end() } fn handle_add_trusted_root_certificate( diff --git a/rs-matter/src/dm/clusters/on_off.rs b/rs-matter/src/dm/clusters/on_off.rs index 0c9d3cee2..b91634bba 100644 --- a/rs-matter/src/dm/clusters/on_off.rs +++ b/rs-matter/src/dm/clusters/on_off.rs @@ -44,6 +44,7 @@ pub use crate::dm::clusters::decl::on_off::*; use crate::tlv::Nullable; use crate::utils::cell::RefCell; +use crate::utils::future::delayed_ready; use crate::utils::sync::blocking::Mutex; use crate::utils::sync::Signal; @@ -742,186 +743,214 @@ impl ClusterAsyncHandler for OnOffHandler< } // Attribute accessors - async fn on_off(&self, _ctx: impl ReadContext) -> Result { - Ok(self.hooks.on_off()) + fn on_off(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(move || Ok(self.hooks.on_off())) } - async fn global_scene_control(&self, _ctx: impl ReadContext) -> Result { - Ok(self.with_state(|state| state.global_scene_control)) + fn global_scene_control( + &self, + _ctx: impl ReadContext, + ) -> impl Future> { + delayed_ready(move || Ok(self.with_state(|state| state.global_scene_control))) } - async fn on_time(&self, _ctx: impl ReadContext) -> Result { - Ok(self.with_state(|state| state.on_time)) + fn on_time(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(move || Ok(self.with_state(|state| state.on_time))) } - async fn off_wait_time(&self, _ctx: impl ReadContext) -> Result { - Ok(self.with_state(|state| state.off_wait_time)) + fn off_wait_time(&self, _ctx: impl ReadContext) -> impl Future> { + delayed_ready(move || Ok(self.with_state(|state| state.off_wait_time))) } - async fn start_up_on_off( + fn start_up_on_off( &self, _ctx: impl ReadContext, - ) -> Result, Error> { - Ok(self.hooks.start_up_on_off()) + ) -> impl Future, Error>> { + delayed_ready(move || Ok(self.hooks.start_up_on_off())) } - async fn set_on_time(&self, ctx: impl WriteContext, value: u16) -> Result<(), Error> { - self.with_state(|state| { - state.on_time = value; - self.dataver_changed(); - 
ctx.notify_changed(); - Ok(()) + fn set_on_time( + &self, + ctx: impl WriteContext, + value: u16, + ) -> impl Future> { + delayed_ready(move || { + self.with_state(|state| { + state.on_time = value; + self.dataver_changed(); + ctx.notify_changed(); + Ok(()) + }) }) } - async fn set_off_wait_time(&self, ctx: impl WriteContext, value: u16) -> Result<(), Error> { - self.with_state(|state| { - state.off_wait_time = value; - self.dataver_changed(); - ctx.notify_changed(); - Ok(()) + fn set_off_wait_time( + &self, + ctx: impl WriteContext, + value: u16, + ) -> impl Future> { + delayed_ready(move || { + self.with_state(|state| { + state.off_wait_time = value; + self.dataver_changed(); + ctx.notify_changed(); + Ok(()) + }) }) } - async fn set_start_up_on_off( + fn set_start_up_on_off( &self, ctx: impl WriteContext, value: Nullable, - ) -> Result<(), Error> { - self.hooks.set_start_up_on_off(value)?; - self.dataver_changed(); - ctx.notify_changed(); - Ok(()) + ) -> impl Future> { + delayed_ready(move || { + self.hooks.set_start_up_on_off(value)?; + self.dataver_changed(); + ctx.notify_changed(); + Ok(()) + }) } // Commands - async fn handle_off(&self, _ctx: impl InvokeContext) -> Result<(), Error> { - self.state_change_signal.signal(OnOffCommand::Off); - - Ok(()) + fn handle_off(&self, _ctx: impl InvokeContext) -> impl Future> { + delayed_ready(move || { + self.state_change_signal.signal(OnOffCommand::Off); + Ok(()) + }) } - async fn handle_on(&self, _ctx: impl InvokeContext) -> Result<(), Error> { - self.state_change_signal.signal(OnOffCommand::On); - - Ok(()) + fn handle_on(&self, _ctx: impl InvokeContext) -> impl Future> { + delayed_ready(move || { + self.state_change_signal.signal(OnOffCommand::On); + Ok(()) + }) } - async fn handle_toggle(&self, _ctx: impl InvokeContext) -> Result<(), Error> { - self.state_change_signal.signal(OnOffCommand::Toggle); - - Ok(()) + fn handle_toggle(&self, _ctx: impl InvokeContext) -> impl Future> { + delayed_ready(move || { + 
self.state_change_signal.signal(OnOffCommand::Toggle); + Ok(()) + }) } - async fn handle_off_with_effect( + fn handle_off_with_effect( &self, _ctx: impl InvokeContext, request: OffWithEffectRequest<'_>, - ) -> Result<(), Error> { - if !Self::supports_feature(on_off::Feature::LIGHTING.bits()) { - // This error is currently mapped to the IM status UnsupportedCommand. - return Err(ErrorCode::CommandNotFound.into()); - } + ) -> impl Future> { + delayed_ready(move || { + if !Self::supports_feature(on_off::Feature::LIGHTING.bits()) { + // This error is currently mapped to the IM status UnsupportedCommand. + return Err(ErrorCode::CommandNotFound.into()); + } - let effect_variant = match request.effect_identifier()? { - EffectIdentifierEnum::DelayedAllOff => { - match request.effect_variant()? { - // todo Impl TryFrom for DelayedAllOffEffectVariantEnum and remove this match. - 0 => EffectVariantEnum::DelayedAllOff( - DelayedAllOffEffectVariantEnum::DelayedOffFastFade, - ), - 1 => EffectVariantEnum::DelayedAllOff(DelayedAllOffEffectVariantEnum::NoFade), - 2 => EffectVariantEnum::DelayedAllOff( - DelayedAllOffEffectVariantEnum::DelayedOffSlowFade, - ), - _ => return Err(ErrorCode::Failure.into()), + let effect_variant = match request.effect_identifier()? { + EffectIdentifierEnum::DelayedAllOff => { + match request.effect_variant()? { + // todo Impl TryFrom for DelayedAllOffEffectVariantEnum and remove this match. + 0 => EffectVariantEnum::DelayedAllOff( + DelayedAllOffEffectVariantEnum::DelayedOffFastFade, + ), + 1 => { + EffectVariantEnum::DelayedAllOff(DelayedAllOffEffectVariantEnum::NoFade) + } + 2 => EffectVariantEnum::DelayedAllOff( + DelayedAllOffEffectVariantEnum::DelayedOffSlowFade, + ), + _ => return Err(ErrorCode::Failure.into()), + } } - } - EffectIdentifierEnum::DyingLight => { - match request.effect_variant()? { - // todo Impl TryFrom for DyingLightEffectVariantEnum and remove this match. 
- 0 => EffectVariantEnum::DyingLight( - DyingLightEffectVariantEnum::DyingLightFadeOff, - ), - _ => return Err(ErrorCode::Failure.into()), + EffectIdentifierEnum::DyingLight => { + match request.effect_variant()? { + // todo Impl TryFrom for DyingLightEffectVariantEnum and remove this match. + 0 => EffectVariantEnum::DyingLight( + DyingLightEffectVariantEnum::DyingLightFadeOff, + ), + _ => return Err(ErrorCode::Failure.into()), + } } - } - }; + }; - self.state_change_signal - .signal(OnOffCommand::OffWithEffect(effect_variant)); + self.state_change_signal + .signal(OnOffCommand::OffWithEffect(effect_variant)); - Ok(()) + Ok(()) + }) } - async fn handle_on_with_recall_global_scene( + fn handle_on_with_recall_global_scene( &self, _ctx: impl InvokeContext, - ) -> Result<(), Error> { - self.with_state(|state| { - // 1.5.7.5.1. Effect on Receipt - // On receipt of the OnWithRecallGlobalScene command, if the GlobalSceneControl attribute is equal - // to TRUE, the server SHALL discard the command. - if state.global_scene_control { - return Ok(()); - } - - // If the GlobalSceneControl attribute is equal to FALSE, the Scene cluster server on the same endpoint - // SHALL recall its global scene, updating the OnOff attribute accordingly. The OnOff server SHALL - // then set the GlobalSceneControl attribute to TRUE. - // Additionally, when the OnTime and OffWaitTime attributes are both supported, if the value of the - // OnTime attribute is equal to 0, the server SHALL set the OffWaitTime attribute to 0. - // todo Implement the above statement once the Scene cluster is implemented. - // self.set_on(false); + ) -> impl Future> { + delayed_ready(move || { + self.with_state(|state| { + // 1.5.7.5.1. Effect on Receipt + // On receipt of the OnWithRecallGlobalScene command, if the GlobalSceneControl attribute is equal + // to TRUE, the server SHALL discard the command. 
+ if state.global_scene_control { + return Ok(()); + } - // This error is currently mapped to the IM status UnsupportedCommand. - Err(ErrorCode::CommandNotFound.into()) + // If the GlobalSceneControl attribute is equal to FALSE, the Scene cluster server on the same endpoint + // SHALL recall its global scene, updating the OnOff attribute accordingly. The OnOff server SHALL + // then set the GlobalSceneControl attribute to TRUE. + // Additionally, when the OnTime and OffWaitTime attributes are both supported, if the value of the + // OnTime attribute is equal to 0, the server SHALL set the OffWaitTime attribute to 0. + // todo Implement the above statement once the Scene cluster is implemented. + // self.set_on(false); + + // This error is currently mapped to the IM status UnsupportedCommand. + Err(ErrorCode::CommandNotFound.into()) + }) }) } - async fn handle_on_with_timed_off( + fn handle_on_with_timed_off( &self, ctx: impl InvokeContext, request: OnWithTimedOffRequest<'_>, - ) -> Result<(), Error> { - // 1.5.7.6.4. Effect on Receipt - // On receipt of this command, if the AcceptOnlyWhenOn sub-field of the OnOffControl field is set to 1, - // and the value of the OnOff attribute is equal to FALSE, the command SHALL be discarded. - if request - .on_off_control()? - .contains(OnOffControlBitmap::ACCEPT_ONLY_WHEN_ON) - && !self.hooks.on_off() - { - return Ok(()); - } - - self.with_state(|state| { - // If the value of the OffWaitTime attribute is greater than zero and the value of the OnOff attribute is - // equal to FALSE, then the server SHALL set the OffWaitTime attribute to the minimum of the - // OffWaitTime attribute and the value specified in the OffWaitTime field. 
- if state.off_wait_time > 0 && !self.hooks.on_off() { - state.off_wait_time = state.off_wait_time.min(request.off_wait_time()?); - } - // In all other cases, the server SHALL set the OnTime attribute to the maximum of the OnTime - // attribute and the value specified in the OnTime field, set the OffWaitTime attribute to the value - // specified in the OffWaitTime field and set the OnOff attribute to TRUE. - else { - state.on_time = state.on_time.max(request.on_time()?); - state.off_wait_time = request.off_wait_time()?; - self.set_on(state, false, &ctx); - } - - // If the values of the OnTime and OffWaitTime attributes are both not equal to 0xFFFF, the server - // SHALL then update these attributes every 1/10th second until both the OnTime and OffWaitTime - // attributes are equal to 0, as follows: - if state.on_time == 0xFFFF && state.off_wait_time == 0xFFFF { + ) -> impl Future> { + delayed_ready(move || { + // 1.5.7.6.4. Effect on Receipt + // On receipt of this command, if the AcceptOnlyWhenOn sub-field of the OnOffControl field is set to 1, + // and the value of the OnOff attribute is equal to FALSE, the command SHALL be discarded. + if request + .on_off_control()? + .contains(OnOffControlBitmap::ACCEPT_ONLY_WHEN_ON) + && !self.hooks.on_off() + { return Ok(()); } - self.state_change_signal - .signal(OnOffCommand::OnWithTimedOff); + self.with_state(|state| { + // If the value of the OffWaitTime attribute is greater than zero and the value of the OnOff attribute is + // equal to FALSE, then the server SHALL set the OffWaitTime attribute to the minimum of the + // OffWaitTime attribute and the value specified in the OffWaitTime field. 
+ if state.off_wait_time > 0 && !self.hooks.on_off() { + state.off_wait_time = state.off_wait_time.min(request.off_wait_time()?); + } + // In all other cases, the server SHALL set the OnTime attribute to the maximum of the OnTime + // attribute and the value specified in the OnTime field, set the OffWaitTime attribute to the value + // specified in the OffWaitTime field and set the OnOff attribute to TRUE. + else { + state.on_time = state.on_time.max(request.on_time()?); + state.off_wait_time = request.off_wait_time()?; + self.set_on(state, false, &ctx); + } - Ok(()) + // If the values of the OnTime and OffWaitTime attributes are both not equal to 0xFFFF, the server + // SHALL then update these attributes every 1/10th second until both the OnTime and OffWaitTime + // attributes are equal to 0, as follows: + if state.on_time == 0xFFFF && state.off_wait_time == 0xFFFF { + return Ok(()); + } + + self.state_change_signal + .signal(OnOffCommand::OnWithTimedOff); + + Ok(()) + }) }) } } diff --git a/rs-matter/src/dm/endpoints.rs b/rs-matter/src/dm/endpoints.rs index 09ef6d25a..48e6528c6 100644 --- a/rs-matter/src/dm/endpoints.rs +++ b/rs-matter/src/dm/endpoints.rs @@ -17,6 +17,7 @@ use rand_core::RngCore; +use crate::dm::clusters::net_comm::{NetworksAccess, SharedNetworks}; use crate::{devices, handler_chain_type}; use super::clusters::acl::{self, AclHandler, ClusterHandler as _}; @@ -29,7 +30,7 @@ use super::clusters::gen_diag::{self, ClusterHandler as _, GenDiag, GenDiagHandl use super::clusters::groups::{self, ClusterHandler as _, GroupsHandler}; use super::clusters::grp_key_mgmt::{self, ClusterHandler as _, GrpKeyMgmtHandler}; use super::clusters::net_comm::{ - self, ClusterAsyncHandler as _, NetCommHandler, NetCtl, NetCtlStatus, NetworkType, Networks, + self, ClusterAsyncHandler as _, NetCommHandler, NetCtl, NetCtlStatus, NetworkType, }; use super::clusters::noc::{self, ClusterHandler as _, NocHandler}; use super::clusters::thread_diag::{self, ClusterHandler as _, 
ThreadDiag, ThreadDiagHandler}; @@ -58,23 +59,23 @@ pub const fn root_endpoint_with_groups(net_type: NetworkType) -> Endpoint<'stati /// A type alias for the handler chain returned by `with_eth()`. pub type EthHandler<'a, H> = NetHandler< 'a, - net_comm::HandlerAsyncAdaptor>, + net_comm::HandlerAsyncAdaptor>, EthNetCtl>>, Async>, H, >; /// A type alias for the handler chain returned by `with_wifi()`. -pub type WifiHandler<'a, T, H> = NetHandler< +pub type WifiHandler<'a, N, T, H> = NetHandler< 'a, - net_comm::HandlerAsyncAdaptor>, + net_comm::HandlerAsyncAdaptor>, Async>>, H, >; /// A type alias for the handler chain returned by `with_thread()`. -pub type ThreadHandler<'a, T, H> = NetHandler< +pub type ThreadHandler<'a, N, T, H> = NetHandler< 'a, - net_comm::HandlerAsyncAdaptor>, + net_comm::HandlerAsyncAdaptor>, Async>>, H, >; @@ -127,8 +128,6 @@ pub fn with_eth<'a, R: RngCore, H>( mut rand: R, handler: H, ) -> EthHandler<'a, H> { - const NETWORK: EthNetwork<'static> = EthNetwork::new("eth"); - ChainedHandler::new( EpClMatcher::new(Some(ROOT_ENDPOINT_ID), Some(GenDiagHandler::CLUSTER.id)), Async(GenDiagHandler::new(Dataver::new_rand(&mut rand), gen_diag, netif_diag).adapt()), @@ -141,9 +140,14 @@ pub fn with_eth<'a, R: RngCore, H>( .chain( EpClMatcher::new( Some(ROOT_ENDPOINT_ID), - Some(NetCommHandler::::CLUSTER.id), + Some(NetCommHandler::>, EthNetCtl>::CLUSTER.id), ), - NetCommHandler::new(Dataver::new_rand(&mut rand), &NETWORK, EthNetCtl).adapt(), + NetCommHandler::new( + Dataver::new_rand(&mut rand), + SharedNetworks::new(EthNetwork::new("eth")), + EthNetCtl, + ) + .adapt(), ) } @@ -161,15 +165,16 @@ pub fn with_eth<'a, R: RngCore, H>( /// - `networks`: The `Networks` implementation. /// - `rand`: A random number generator. /// - `handler`: The handler to be decorated. 
-pub fn with_wifi<'a, R: RngCore, T, H>( +pub fn with_wifi<'a, R: RngCore, N, T, H>( gen_diag: &'a dyn GenDiag, netif_diag: &'a dyn NetifDiag, + networks: N, net_ctl: &'a T, - networks: &'a dyn Networks, mut rand: R, handler: H, -) -> WifiHandler<'a, &'a T, H> +) -> WifiHandler<'a, N, &'a T, H> where + N: NetworksAccess, T: NetCtl + NetCtlStatus + WifiDiag, { ChainedHandler::new( @@ -184,7 +189,7 @@ where .chain( EpClMatcher::new( Some(ROOT_ENDPOINT_ID), - Some(NetCommHandler::::CLUSTER.id), + Some(NetCommHandler::::CLUSTER.id), ), NetCommHandler::new(Dataver::new_rand(&mut rand), networks, net_ctl).adapt(), ) @@ -203,15 +208,16 @@ where /// - `net_ctl`: The `NetCtl` implementation. /// - `networks`: The `Networks` implementation. /// - `rand`: A random number generator. -pub fn with_thread<'a, R: RngCore, T, H>( +pub fn with_thread<'a, R: RngCore, N, T, H>( gen_diag: &'a dyn GenDiag, netif_diag: &'a dyn NetifDiag, + networks: N, net_ctl: &'a T, - networks: &'a dyn Networks, mut rand: R, handler: H, -) -> ThreadHandler<'a, &'a T, H> +) -> ThreadHandler<'a, N, &'a T, H> where + N: NetworksAccess, T: NetCtl + NetCtlStatus + ThreadDiag, { ChainedHandler::new( @@ -226,7 +232,7 @@ where .chain( EpClMatcher::new( Some(ROOT_ENDPOINT_ID), - Some(NetCommHandler::::CLUSTER.id), + Some(NetCommHandler::::CLUSTER.id), ), NetCommHandler::new(Dataver::new_rand(&mut rand), networks, net_ctl).adapt(), ) diff --git a/rs-matter/src/dm/events.rs b/rs-matter/src/dm/events.rs index a3b4182cc..2aa5ce50a 100644 --- a/rs-matter/src/dm/events.rs +++ b/rs-matter/src/dm/events.rs @@ -14,19 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -use core::cell::Cell; + use core::fmt::Debug; use crate::error::{Error, ErrorCode}; use crate::im::{EventData, EventDataTag, EventDataTimestamp, EventPath, EventRespTag}; +use crate::persist::{KvBlobStore, KvBlobStoreAccess, Persist, EVENT_EPOCH_KEY}; use crate::tlv::{ FromTLV, TLVElement, TLVSequence, TLVSequenceIter, TLVTag, TLVWrite, TagType, ToTLV, }; +use crate::utils::cell::RefCell; use crate::utils::epoch::Epoch; use crate::utils::init::{init, Init}; use crate::utils::storage::WriteBuf; use crate::utils::sync::blocking; -use crate::utils::sync::Notification; use crate::utils::sync::{IfMutex, IfMutexGuard}; // TODO we currently only have this singular config to set the size, but the events are stored in three "tiered" buffers, and @@ -64,7 +65,44 @@ pub(crate) struct PersistedState { /// emitted without a new flash write. When we approach this boundary /// (within `EVENT_ID_COUNTER_PERSIST_AHEAD`), we bump it by another epoch. pub(crate) event_epoch_end: u64, - pub(crate) changed: bool, +} + +impl PersistedState { + const fn new() -> Self { + Self { + next_event_no: 0, + event_epoch_end: 0, + } + } + + fn init() -> impl Init { + init!(Self { + next_event_no: 0, + event_epoch_end: 0, + }) + } + + fn reset(&mut self) { + self.next_event_no = 0; + self.event_epoch_end = 0; + } + + /// Restore events from previously persisted state. + fn load(&mut self, data: &[u8]) -> Result<(), Error> { + let epoch_end = TLVElement::new(data).u64()?; + + self.next_event_no = epoch_end; + self.event_epoch_end = epoch_end.saturating_add(EVENT_NO_EPOCH_SIZE); + + Ok(()) + } + + /// Store events persistence state into a byte slice. 
+ fn store(&mut self, buf: &mut [u8]) -> Result { + let mut wb = WriteBuf::new(buf); + wb.u64(&TLVTag::Anonymous, self.event_epoch_end)?; + Ok(wb.get_tail()) + } } /// This is the event queue system, it lets you publish Matter Events into a priority queue, @@ -82,25 +120,17 @@ pub(crate) struct PersistedState { /// If your application emits no events you can disable this subsystem using the NO_EVENTS constant. pub struct Events { state: IfMutex>, + pub(crate) persisted_state: blocking::Mutex>, epoch: Epoch, - pub(crate) persisted_state: blocking::Mutex>, - persist_notification: Notification, } impl Events { - const PERSIST_INIT: PersistedState = PersistedState { - next_event_no: 0, - event_epoch_end: 0, - changed: false, - }; - #[inline(always)] pub const fn new(epoch: Epoch) -> Self { Self { state: IfMutex::new(EventsInner::new()), + persisted_state: blocking::Mutex::new(RefCell::new(PersistedState::new())), epoch, - persisted_state: blocking::Mutex::new(Cell::new(Self::PERSIST_INIT)), - persist_notification: Notification::new(), } } @@ -114,46 +144,55 @@ impl Events { pub fn init(epoch: Epoch) -> impl Init { init!(Self { state <- IfMutex::init(EventsInner::init()), + persisted_state <- blocking::Mutex::init(RefCell::init(PersistedState::init())), epoch, - persisted_state: blocking::Mutex::new(Cell::new(Self::PERSIST_INIT)), - persist_notification: Notification::new(), }) } + pub fn reset(&mut self) { + self.state.get_mut().reset(); + self.persisted_state.get_mut().get_mut().reset(); + } + /// Push a new event into the event queue, making it visible to any subscribers. /// NOTE: This API is unstable and may change, for instance it currently requires knowing the tag number for writing data /// /// To write event data you use the provided EventQueueWriter and write the tag EventDataTag::Data. 
- pub async fn push( + pub async fn push( &self, path: EventPath, priority: u8, + kv: S, data: impl FnOnce(&mut EventQueueWriter) -> Result<(), Error>, - ) -> Result<(), Error> { - let (event_no, notify) = self.persisted_state.lock(|cell| { - let mut state = cell.get(); - let event_no = state.next_event_no; - state.next_event_no += 1; - - let notify = if state.next_event_no - >= state - .event_epoch_end - .saturating_sub(EVENT_ID_COUNTER_PERSIST_AHEAD) - { - state.event_epoch_end = state.next_event_no.saturating_add(EVENT_NO_EPOCH_SIZE); - state.changed = true; - true - } else { - false - }; - - cell.set(state); - (event_no, notify) - }); - - if notify { - self.persist_notification.notify(); - } + ) -> Result<(), Error> + where + S: KvBlobStoreAccess, + { + let event_no = { + let mut persist = Persist::new(kv); + + let event_no = self.persisted_state.lock(|cell| { + let mut state = cell.borrow_mut(); + let event_no = state.next_event_no; + state.next_event_no += 1; + + if state.next_event_no + >= state + .event_epoch_end + .saturating_sub(EVENT_ID_COUNTER_PERSIST_AHEAD) + { + state.event_epoch_end = state.next_event_no.saturating_add(EVENT_NO_EPOCH_SIZE); + + persist.store(EVENT_EPOCH_KEY, |buf| state.store(buf).map(Some))?; + } + + Ok::<_, Error>(event_no) + })?; + + persist.run()?; + + event_no + }; let mut internal = self.state.lock().await; let timestamp = EventDataTimestamp::EpochTimestamp((self.epoch)().as_millis() as u64); @@ -172,51 +211,44 @@ impl Events { // TODO(events) we can't do it like this, this will miss events when pushing happens after for_each but before we call this // we need to return the last processed one from for_each or something like that pub fn peek_next_event_no(&self) -> u64 { - self.persisted_state.lock(|cell| cell.get().next_event_no) + self.persisted_state + .lock(|cell| cell.borrow().next_event_no) } - /// True if the persisted state has changed since the last time store() was called. 
- pub fn changed(&self) -> bool { - self.persisted_state.lock(|cell| cell.get().changed) - } + /// Remove persisted state from the given key-value store. + pub async fn reset_persist(&mut self, mut kv: S, buf: &mut [u8]) -> Result<(), Error> + where + S: KvBlobStore, + { + self.reset(); - /// Restore events from previously persisted state. - pub fn load(&self, data: &[u8]) -> Result<(), Error> { - let epoch_end = TLVElement::new(data).u64()?; + kv.remove(EVENT_EPOCH_KEY, buf)?; - self.persisted_state.lock(|cell| { - cell.set(PersistedState { - next_event_no: epoch_end, - event_epoch_end: epoch_end.saturating_add(EVENT_NO_EPOCH_SIZE), - changed: true, - }); - }); + info!("Removed events counter from storage"); Ok(()) } - /// Store events persistence state into a byte slice. - pub fn store(&self, buf: &mut [u8]) -> Result { - let epoch_end = self.persisted_state.lock(|cell| { - let mut state = cell.get(); - state.changed = false; - cell.set(state); - state.event_epoch_end - }); + /// Load persisted state from the given key-value store, so that we can continue emitting events without reusing event numbers. + pub async fn load_persist(&mut self, mut kv: S, buf: &mut [u8]) -> Result<(), Error> + where + S: KvBlobStore, + { + self.reset(); - let mut wb = WriteBuf::new(buf); - wb.u64(&TLVTag::Anonymous, epoch_end)?; - Ok(wb.get_tail()) - } + if let Some(data) = kv.load(EVENT_EPOCH_KEY, buf)? { + self.load(data)?; - /// Wait until the persistent state needs re-persisting. - pub async fn wait_persist(&self) { - loop { - if self.persisted_state.lock(|cell| cell.get().changed) { - break; - } - self.persist_notification.wait().await; + info!("Loaded events counter from storage"); } + + Ok(()) + } + + /// Restore events from previously persisted state. 
+ pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { + let cell = self.persisted_state.get_mut(); + cell.borrow_mut().load(data) } } @@ -281,6 +313,12 @@ impl EventsInner { }) } + pub fn reset(&mut self) { + self.buf_debug.reset(); + self.buf_info.reset(); + self.buf_critical.reset(); + } + pub fn push<'a>( &'a mut self, path: EventPath, @@ -521,6 +559,10 @@ impl LevelBuf { }) } + pub fn reset(&mut self) { + self.head = 0; + } + fn write(&mut self, byte: u8) -> Result<(), OverflowError> { if self.capacity() == 0 { return Err(OverflowError {}); diff --git a/rs-matter/src/dm/networks/eth.rs b/rs-matter/src/dm/networks/eth.rs index 964155a28..44d05e3bb 100644 --- a/rs-matter/src/dm/networks/eth.rs +++ b/rs-matter/src/dm/networks/eth.rs @@ -22,7 +22,6 @@ use crate::dm::clusters::net_comm::{ NetworkType, NetworksError, WirelessCreds, }; use crate::error::{Error, ErrorCode}; -use crate::utils::sync::DynBase; /// A fixed `Networks` trait implementation for Ethernet. /// @@ -39,8 +38,6 @@ impl<'a> EthNetwork<'a> { } } -impl DynBase for EthNetwork<'_> {} - impl net_comm::Networks for EthNetwork<'_> { fn max_networks(&self) -> Result { Ok(1) @@ -76,21 +73,25 @@ impl net_comm::Networks for EthNetwork<'_> { Ok(true) } - fn set_enabled(&self, _enabled: bool) -> Result<(), Error> { + fn set_enabled(&mut self, _enabled: bool) -> Result<(), Error> { Ok(()) } - fn add_or_update(&self, _creds: &WirelessCreds<'_>) -> Result { + fn add_or_update(&mut self, _creds: &WirelessCreds<'_>) -> Result { Err(NetworksError::Other(ErrorCode::InvalidAction.into())) } - fn reorder(&self, _index: u8, _network_id: &[u8]) -> Result { + fn reorder(&mut self, _index: u8, _network_id: &[u8]) -> Result { Err(NetworksError::Other(ErrorCode::InvalidAction.into())) } - fn remove(&self, _network_id: &[u8]) -> Result { + fn remove(&mut self, _network_id: &[u8]) -> Result { Err(NetworksError::Other(ErrorCode::InvalidAction.into())) } + + fn persist(&self, _buf: &mut [u8]) -> Result, Error> { + Ok(None) 
+ } } /// A `net_comm::NetCtl` implementation for Ethernet that errors out on all methods. diff --git a/rs-matter/src/dm/networks/wireless.rs b/rs-matter/src/dm/networks/wireless.rs index e4d0d793e..ff677fcbb 100644 --- a/rs-matter/src/dm/networks/wireless.rs +++ b/rs-matter/src/dm/networks/wireless.rs @@ -20,19 +20,20 @@ use core::fmt::{Debug, Display}; use crate::dm::clusters::net_comm::{ - self, NetCtlError, NetworkCommissioningStatusEnum, NetworkType, NetworksError, + self, NetCtlError, NetworkCommissioningStatusEnum, NetworkType, Networks, NetworksError, ThreadCapabilitiesBitmap, WirelessCreds, }; use crate::dm::clusters::{thread_diag, wifi_diag}; use crate::error::{Error, ErrorCode}; use crate::fmt::Bytes; +use crate::persist::{KvBlobStore, NETWORKS_KEY}; use crate::tlv::{FromTLV, TLVElement, TLVTag, ToTLV}; use crate::transport::network::btp::Btp; use crate::utils::cell::RefCell; use crate::utils::init::{init, Init}; use crate::utils::storage::{Vec, WriteBuf}; -use crate::utils::sync::blocking::{self, Mutex}; -use crate::utils::sync::{DynBase, Notification}; +use crate::utils::sync::blocking; +use crate::utils::sync::DynBase; use super::NetChangeNotif; @@ -102,316 +103,90 @@ pub trait WirelessNetwork: Send + for<'a> FromTLV<'a> + ToTLV { } /// A fixed-size storage for wireless networks credentials. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct WirelessNetworks { - state: Mutex>>, - state_changed: Notification, - persist_state_changed: Notification, + networks: crate::utils::storage::Vec, +} + +impl Default for WirelessNetworks +where + T: WirelessNetwork, +{ + fn default() -> Self { + Self::new() + } } impl WirelessNetworks where T: WirelessNetwork, { - /// Create a new instance. 
pub const fn new() -> Self { Self { - state: Mutex::new(RefCell::new(WirelessNetworksStore::new())), - state_changed: Notification::new(), - persist_state_changed: Notification::new(), + networks: crate::utils::storage::Vec::new(), } } - /// Return an in-place initializer for the struct. pub fn init() -> impl Init { init!(Self { - state <- Mutex::init(RefCell::init(WirelessNetworksStore::init())), - state_changed: Notification::new(), - persist_state_changed: Notification::new(), + networks <- crate::utils::storage::Vec::init(), }) } /// Reset the state - /// - /// # Arguments - /// - `flag_changed`: If `true`, the state will be marked as changed - pub fn reset(&self, flag_changed: bool) { - self.state - .lock(|state| state.borrow_mut().reset(flag_changed)); - - self.state_changed.notify(); - - if flag_changed { - self.persist_state_changed.notify(); - } - } - - /// Load the state from a byte slice. - /// - /// # Arguments - /// - `data`: The byte slice to load the state from - pub fn load(&self, data: &[u8]) -> Result<(), Error> { - self.state.lock(|state| state.borrow_mut().load(data)) - } - - /// Store the state into a byte slice. - /// - /// # Arguments - /// - `buf`: The byte slice to store the state into - /// - /// Returns the number of bytes written into the buffer. - pub fn store(&self, buf: &mut [u8]) -> Result { - self.state.lock(|state| state.borrow_mut().store(buf)) - } - - /// Return `true` if the state has changed. - pub fn changed(&self) -> bool { - self.state.lock(|state| state.borrow().changed) - } - - /// Wait for the state to change. - pub async fn wait_state_changed(&self) { - loop { - if self.state.lock(|state| state.borrow().changed) { - break; - } - - self.state_changed.wait().await; - } - } - - /// Wait for the state to be changed in a way that requires persisting. 
- pub async fn wait_persist(&self) { - loop { - if self.state.lock(|state| state.borrow().changed) { - break; - } - - self.persist_state_changed.wait().await; - } - } - - /// Iterate over the registered network credentials - /// - /// # Arguments - /// - `f`: A closure that will be called for each network registered in the storage - pub fn networks(&self, f: F) -> Result<(), Error> - where - F: FnMut(&T) -> Result<(), Error>, - { - self.state.lock(|state| state.borrow().networks(f)) - } - - /// Get the credentials of a network by its ID - /// - /// # Arguments - /// - `network_id`: The ID of the network to get - /// - `f`: A closure that will be called with the credentials of the network, if the network exists - /// - /// Returns the index of the network in the storage if the network exists, `NetworkError::NetworkIdNotFound` otherwise - pub fn network(&self, network_id: &[u8], f: F) -> Result - where - F: FnOnce(&T) -> Result<(), Error>, - { - self.state - .lock(|state| state.borrow().network(network_id, f)) - } - - /// Get the next network credentials after the one with the given ID - /// - /// # Arguments - /// - `after_network_id`: The ID of the network to get the next one after. - /// If no network with the provided network ID exists, the first network in the storage will be returned. - pub fn next_network(&self, after_network_id: Option<&[u8]>, f: F) -> Result - where - F: FnOnce(&T) -> Result<(), Error>, - { - self.state - .lock(|state| state.borrow_mut().next_network(after_network_id, f)) + pub fn reset(&mut self) { + self.networks.clear(); } - /// Add or update a network in the storage + /// Remove all networks from the provided BLOB store and from memory /// /// # Arguments - /// - `network_id`: The ID of the network to add or update - /// - `add`: An in-place initializer for the network to add. 
The initializer will be used only if a network with the provided - /// network ID does not exist in the storage - /// - `update`: A closure that will be called with the network to update. The closure will be called only if a network with the provided - /// network ID exists in the storage - pub fn add_or_update( - &self, - network_id: &[u8], - add: A, - update: U, - ) -> Result - where - A: Init, - U: FnOnce(&mut T) -> Result<(), Error>, - { - self.state.lock(|state| { - let index = state.borrow_mut().add_or_update(network_id, add, update)?; + /// - `store`: the BLOB store to remove the networks from + /// - `buf`: a temporary buffer to use for removing the networks + pub async fn reset_persist( + &mut self, + mut kv: S, + buf: &mut [u8], + ) -> Result<(), Error> { + self.reset(); - self.state_changed.notify(); - self.persist_state_changed.notify(); + kv.remove(NETWORKS_KEY, buf)?; - Ok(index) - }) - } + info!("Removed all wireless networks from storage"); - /// Reorder a network in the storage - /// - /// # Arguments - /// - `index`: The new index of the network - /// - `network_id`: The ID of the network to reorder - /// - /// Returns the new index of the network in the storage, if a network with the provided ID exists - /// or `NetworkError::NetworkIdNotFound` otherwise - pub fn reorder(&self, index: u8, network_id: &[u8]) -> Result { - self.state.lock(|state| { - let index = state.borrow_mut().reorder(index, network_id)?; - - self.state_changed.notify(); - self.persist_state_changed.notify(); - - Ok(index) - }) + Ok(()) } - /// Remove a network from the storage + /// Load all networks from the provided BLOB store /// /// # Arguments - /// - `network_id`: The ID of the network to remove - /// - /// Returns the index of the network in the storage if the network exists and was removed, `NetworkError::NetworkIdNotFound` otherwise - pub fn remove(&self, network_id: &[u8]) -> Result { - self.state.lock(|state| { - let index = state.borrow_mut().remove(network_id)?; - 
- self.state_changed.notify(); - self.persist_state_changed.notify(); - - Ok(index) - }) - } -} - -impl Default for WirelessNetworks -where - T: WirelessNetwork + Clone, -{ - fn default() -> Self { - Self::new() - } -} - -impl DynBase for WirelessNetworks where T: WirelessNetwork {} - -impl net_comm::Networks for WirelessNetworks -where - T: WirelessNetwork, -{ - fn max_networks(&self) -> Result { - Ok(N as _) - } - - fn networks( - &self, - f: &mut dyn FnMut(&net_comm::NetworkInfo) -> Result<(), Error>, + /// - `store`: the BLOB store to load the networks from + /// - `buf`: a temporary buffer to use for loading the networks + pub async fn load_persist( + &mut self, + mut kv: S, + buf: &mut [u8], ) -> Result<(), Error> { - WirelessNetworks::networks(self, |network| { - let network_id = network.id(); + self.reset(); - let network_info = net_comm::NetworkInfo { - network_id, - connected: false, // TODO - }; + if let Some(data) = kv.load(NETWORKS_KEY, buf)? { + self.load(data)?; - f(&network_info) - }) - } - - fn creds( - &self, - network_id: &[u8], - f: &mut dyn FnMut(&net_comm::WirelessCreds) -> Result<(), Error>, - ) -> Result { - WirelessNetworks::network(self, network_id, |network| f(&network.creds())) - } - - fn next_creds( - &self, - last_network_id: Option<&[u8]>, - f: &mut dyn FnMut(&WirelessCreds) -> Result<(), Error>, - ) -> Result { - WirelessNetworks::next_network(self, last_network_id, |network| f(&network.creds())) - } - - fn enabled(&self) -> Result { - Ok(true) - } - - fn set_enabled(&self, _enabled: bool) -> Result<(), Error> { - Ok(()) - } - - fn add_or_update( - &self, - creds: &net_comm::WirelessCreds<'_>, - ) -> Result { - WirelessNetworks::add_or_update(self, creds.id()?, T::init_from(creds), |network| { - network.update(creds) - }) - } - - fn reorder(&self, index: u8, network_id: &[u8]) -> Result { - WirelessNetworks::reorder(self, index, network_id) - } - - fn remove(&self, network_id: &[u8]) -> Result { - WirelessNetworks::remove(self, 
network_id) - } -} - -impl NetChangeNotif for WirelessNetworks -where - T: WirelessNetwork, -{ - async fn wait_changed(&self) { - self.state_changed.wait().await; - } -} - -/// The internal unsychronized storage for network credentials. -#[derive(Clone, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -struct WirelessNetworksStore { - networks: crate::utils::storage::Vec, - changed: bool, -} - -impl WirelessNetworksStore -where - T: WirelessNetwork, -{ - const fn new() -> Self { - Self { - networks: crate::utils::storage::Vec::new(), - changed: false, + info!( + "Loaded {} wireless networks from storage", + self.networks.len() + ); } - } - - fn init() -> impl Init { - init!(Self { - networks <- crate::utils::storage::Vec::init(), - changed: false, - }) - } - fn reset(&mut self, flag_changed: bool) { - self.networks.clear(); - self.changed = flag_changed; + Ok(()) } - fn load(&mut self, data: &[u8]) -> Result<(), Error> { + /// Load the state from a byte slice. + /// + /// # Arguments + /// - `data`: The byte slice to load the state from + pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { let root = TLVElement::new(data); let iter = root.array()?.iter(); @@ -426,24 +201,30 @@ where })?; } - self.changed = false; - Ok(()) } - fn store(&mut self, buf: &mut [u8]) -> Result { + /// Store the state into a byte slice. + /// + /// # Arguments + /// - `buf`: The byte slice to store the state into + /// + /// Returns the number of bytes written into the buffer. 
+ pub fn store(&self, buf: &mut [u8]) -> Result { let mut wb = WriteBuf::new(buf); self.networks.to_tlv(&TLVTag::Anonymous, &mut wb)?; - self.changed = false; - let tail = wb.get_tail(); Ok(tail) } - fn networks(&self, mut f: F) -> Result<(), Error> + /// Iterate over the registered network credentials + /// + /// # Arguments + /// - `f`: A closure that will be called for each network registered in the storage + pub fn networks(&self, mut f: F) -> Result<(), Error> where F: FnMut(&T) -> Result<(), Error>, { @@ -454,7 +235,14 @@ where Ok(()) } - fn network(&self, network_id: &[u8], f: F) -> Result + /// Get the credentials of a network by its ID + /// + /// # Arguments + /// - `network_id`: The ID of the network to get + /// - `f`: A closure that will be called with the credentials of the network, if the network exists + /// + /// Returns the index of the network in the storage if the network exists, `NetworkError::NetworkIdNotFound` otherwise + pub fn network(&self, network_id: &[u8], f: F) -> Result where F: FnOnce(&T) -> Result<(), Error>, { @@ -473,7 +261,12 @@ where } } - fn next_network(&mut self, last_network_id: Option<&[u8]>, f: F) -> Result + /// Get the next network credentials after the one with the given ID + /// + /// # Arguments + /// - `after_network_id`: The ID of the network to get the next one after. + /// If no network with the provided network ID exists, the first network in the storage will be returned. + pub fn next_network(&self, last_network_id: Option<&[u8]>, f: F) -> Result where F: FnOnce(&T) -> Result<(), Error>, { @@ -516,7 +309,15 @@ where } } - fn add_or_update( + /// Add or update a network in the storage + /// + /// # Arguments + /// - `network_id`: The ID of the network to add or update + /// - `add`: An in-place initializer for the network to add. 
The initializer will be used only if a network with the provided + /// network ID does not exist in the storage + /// - `update`: A closure that will be called with the network to update. The closure will be called only if a network with the provided + /// network ID exists in the storage + pub fn add_or_update( &mut self, network_id: &[u8], add: A, @@ -536,8 +337,6 @@ where // Update update(unetwork)?; - self.changed = true; - info!("Updated network with ID {}", unetwork.display()); Ok(index as _) @@ -553,15 +352,21 @@ where self.networks .push_init(add, || ErrorCode::ResourceExhausted.into())?; - self.changed = true; - info!("Added network with ID {}", T::display_id(network_id)); Ok((self.networks.len() - 1) as _) } } - fn reorder(&mut self, index: u8, network_id: &[u8]) -> Result { + /// Reorder a network in the storage + /// + /// # Arguments + /// - `index`: The new index of the network + /// - `network_id`: The ID of the network to reorder + /// + /// Returns the new index of the network in the storage, if a network with the provided ID exists + /// or `NetworkError::NetworkIdNotFound` otherwise + pub fn reorder(&mut self, index: u8, network_id: &[u8]) -> Result { let cur_index = self .networks .iter() @@ -574,8 +379,6 @@ where let conf = self.networks.remove(cur_index); unwrap!(self.networks.insert(index as usize, conf).map_err(|_| ())); - self.changed = true; - info!( "Network with ID {} reordered to index {}", T::display_id(network_id), @@ -598,7 +401,13 @@ where Ok(index) } - fn remove(&mut self, network_id: &[u8]) -> Result { + /// Remove a network from the storage + /// + /// # Arguments + /// - `network_id`: The ID of the network to remove + /// + /// Returns the index of the network in the storage if the network exists and was removed, `NetworkError::NetworkIdNotFound` otherwise + pub fn remove(&mut self, network_id: &[u8]) -> Result { let index = self .networks .iter() @@ -608,8 +417,6 @@ where // Found self.networks.remove(index); - self.changed = 
true; - info!("Removed network with ID {}", T::display_id(network_id)); Ok(index as _) @@ -621,6 +428,76 @@ where } } +impl Networks for WirelessNetworks +where + T: WirelessNetwork, +{ + fn max_networks(&self) -> Result { + Ok(N as _) + } + + fn networks( + &self, + f: &mut dyn FnMut(&net_comm::NetworkInfo) -> Result<(), Error>, + ) -> Result<(), Error> { + WirelessNetworks::networks(self, |network| { + let network_id = network.id(); + + let network_info = net_comm::NetworkInfo { + network_id, + connected: false, // TODO + }; + + f(&network_info) + }) + } + + fn creds( + &self, + network_id: &[u8], + f: &mut dyn FnMut(&net_comm::WirelessCreds) -> Result<(), Error>, + ) -> Result { + WirelessNetworks::network(self, network_id, |network| f(&network.creds())) + } + + fn next_creds( + &self, + last_network_id: Option<&[u8]>, + f: &mut dyn FnMut(&WirelessCreds) -> Result<(), Error>, + ) -> Result { + WirelessNetworks::next_network(self, last_network_id, |network| f(&network.creds())) + } + + fn enabled(&self) -> Result { + Ok(true) + } + + fn set_enabled(&mut self, _enabled: bool) -> Result<(), Error> { + Ok(()) + } + + fn add_or_update( + &mut self, + creds: &net_comm::WirelessCreds<'_>, + ) -> Result { + WirelessNetworks::add_or_update(self, creds.id()?, T::init_from(creds), |network| { + network.update(creds) + }) + } + + fn reorder(&mut self, index: u8, network_id: &[u8]) -> Result { + WirelessNetworks::reorder(self, index, network_id) + } + + fn remove(&mut self, network_id: &[u8]) -> Result { + WirelessNetworks::remove(self, network_id) + } + + fn persist(&self, buf: &mut [u8]) -> Result, Error> { + WirelessNetworks::store(self, buf).map(Some) + } +} + /// An enum capable of displaying a network ID in a human-readable format. 
#[derive(Debug)] enum DisplayId<'a> { diff --git a/rs-matter/src/dm/networks/wireless/mgr.rs b/rs-matter/src/dm/networks/wireless/mgr.rs index a177c6c96..a44c610cd 100644 --- a/rs-matter/src/dm/networks/wireless/mgr.rs +++ b/rs-matter/src/dm/networks/wireless/mgr.rs @@ -40,7 +40,7 @@ pub struct WirelessMgr<'a, W, T> { impl<'a, W, T> WirelessMgr<'a, W, T> where - W: net_comm::Networks + NetChangeNotif, + W: net_comm::NetworksAccess + NetChangeNotif, T: net_comm::NetCtl + wifi_diag::WirelessDiag + NetChangeNotif, { /// Creates a new `WirelessMgr` instance. @@ -76,36 +76,38 @@ where let mut c = None; - networks.next_creds( - (!network_id.is_empty()).then_some(&network_id), - &mut |creds| { - match creds { - WirelessCreds::Wifi { ssid, pass } => { - if ssid.len() + pass.len() > buf.len() { - error!("SSID and password too large"); - return Err(ErrorCode::InvalidData.into()); + networks.access(|networks| { + networks.next_creds( + (!network_id.is_empty()).then_some(&network_id), + &mut |creds| { + match creds { + WirelessCreds::Wifi { ssid, pass } => { + if ssid.len() + pass.len() > buf.len() { + error!("SSID and password too large"); + return Err(ErrorCode::InvalidData.into()); + } + + buf[..ssid.len()].copy_from_slice(ssid); + buf[ssid.len()..][..pass.len()].copy_from_slice(pass); + + c = Some((ssid.len(), Some(pass.len()))) } + WirelessCreds::Thread { dataset_tlv } => { + if dataset_tlv.len() > buf.len() { + error!("Dataset TLV too large"); + return Err(ErrorCode::InvalidData.into()); + } - buf[..ssid.len()].copy_from_slice(ssid); - buf[ssid.len()..][..pass.len()].copy_from_slice(pass); + buf[..dataset_tlv.len()].copy_from_slice(dataset_tlv); - c = Some((ssid.len(), Some(pass.len()))) - } - WirelessCreds::Thread { dataset_tlv } => { - if dataset_tlv.len() > buf.len() { - error!("Dataset TLV too large"); - return Err(ErrorCode::InvalidData.into()); + c = Some((dataset_tlv.len(), None)) } - - buf[..dataset_tlv.len()].copy_from_slice(dataset_tlv); - - c = 
Some((dataset_tlv.len(), None)) } - } - Ok(()) - }, - )?; + Ok(()) + }, + ) + })?; if let Some((len1, len2)) = c { let creds = if let Some(len2) = len2 { diff --git a/rs-matter/src/dm/types/handler.rs b/rs-matter/src/dm/types/handler.rs index 941b02db1..c2cb674fb 100644 --- a/rs-matter/src/dm/types/handler.rs +++ b/rs-matter/src/dm/types/handler.rs @@ -22,6 +22,7 @@ use crate::crypto::backend::dummy::DummyCrypto; use crate::crypto::Crypto; use crate::dm::IMBuffer; use crate::error::{Error, ErrorCode}; +use crate::persist::{DummyKvBlobStoreAccess, KvBlobStoreAccess}; use crate::tlv::TLVElement; use crate::transport::exchange::Exchange; use crate::utils::select::Coalesce; @@ -63,6 +64,9 @@ pub trait BasicContext { /// Return the crypto object that is associated with this operation. fn crypto(&self) -> impl Crypto + '_; + /// Return a blob store that can be used to persist data across reboots. + fn kv(&self) -> impl KvBlobStoreAccess + '_; + /// Notify that the state of an attribute has changed. /// /// # Arguments @@ -94,6 +98,10 @@ where (**self).crypto() } + fn kv(&self) -> impl KvBlobStoreAccess + '_ { + (**self).kv() + } + fn notify_attribute_changed( &self, endpoint_id: EndptId, @@ -159,6 +167,7 @@ pub trait Context: HandlerContext { DummyCrypto, EmptyHandler, PooledBuffers<0, IMBuffer>, + DummyKvBlobStoreAccess, >, >::None } @@ -171,6 +180,7 @@ pub trait Context: HandlerContext { DummyCrypto, EmptyHandler, PooledBuffers<0, IMBuffer>, + DummyKvBlobStoreAccess, >, >::None } @@ -179,7 +189,12 @@ pub trait Context: HandlerContext { /// The operation will return `Some` only if the underlying context represents an invoke operation. 
fn as_invoke_ctx(&self) -> Option { Option::< - &'static InvokeContextInstance>, + &'static InvokeContextInstance< + DummyCrypto, + EmptyHandler, + PooledBuffers<0, IMBuffer>, + DummyKvBlobStoreAccess, + >, >::None } } @@ -267,68 +282,22 @@ where } } -/// A concrete implementation of the `BasicContext` trait -pub(crate) struct BasicContextInstance<'a, C> { - matter: &'a Matter<'a>, - crypto: C, - pub(crate) notify: &'a dyn ChangeNotify, -} - -impl<'a, C> BasicContextInstance<'a, C> -where - C: Crypto, -{ - /// Construct a new instance. - #[inline(always)] - pub(crate) const fn new( - matter: &'a Matter<'a>, - crypto: C, - notify: &'a dyn ChangeNotify, - ) -> Self { - Self { - matter, - crypto, - notify, - } - } -} - -impl BasicContext for BasicContextInstance<'_, C> -where - C: Crypto, -{ - fn matter(&self) -> &Matter<'_> { - self.matter - } - - fn crypto(&self) -> impl Crypto + '_ { - &self.crypto - } - - fn notify_attribute_changed( - &self, - endpoint_id: EndptId, - cluster_id: ClusterId, - attr_id: AttrId, - ) { - self.notify.notify(endpoint_id, cluster_id, attr_id); - } -} - /// A concrete implementation of the `HandlerContext` trait -pub(crate) struct HandlerContextInstance<'a, C, T, B> { +pub(crate) struct HandlerContextInstance<'a, C, T, B, S> { matter: &'a Matter<'a>, crypto: C, handler: T, buffers: B, + kv: S, pub(crate) notify: &'a dyn ChangeNotify, } -impl<'a, C, T, B> HandlerContextInstance<'a, C, T, B> +impl<'a, C, T, B, S> HandlerContextInstance<'a, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { /// Construct a new instance. 
#[inline(always)] @@ -337,6 +306,7 @@ where crypto: C, handler: T, buffers: B, + kv: S, notify: &'a dyn ChangeNotify, ) -> Self { Self { @@ -344,16 +314,18 @@ where crypto, handler, buffers, + kv, notify, } } } -impl BasicContext for HandlerContextInstance<'_, C, T, B> +impl BasicContext for HandlerContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn matter(&self) -> &Matter<'_> { self.matter @@ -363,6 +335,10 @@ where &self.crypto } + fn kv(&self) -> impl KvBlobStoreAccess + '_ { + &self.kv + } + fn notify_attribute_changed( &self, endpoint_id: EndptId, @@ -373,11 +349,12 @@ where } } -impl HandlerContext for HandlerContextInstance<'_, C, T, B> +impl HandlerContext for HandlerContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn handler(&self) -> impl AsyncHandler + '_ { &self.handler @@ -389,20 +366,22 @@ where } /// A concrete implementation of the `ReadContext` trait -pub(crate) struct ReadContextInstance<'a, C, T, B> { +pub(crate) struct ReadContextInstance<'a, C, T, B, S> { exchange: &'a Exchange<'a>, crypto: C, handler: T, buffers: B, + kv: S, attr: &'a AttrDetails<'a>, pub(crate) notify: &'a dyn ChangeNotify, } -impl<'a, C, T, B> ReadContextInstance<'a, C, T, B> +impl<'a, C, T, B, S> ReadContextInstance<'a, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { /// Construct a new instance. 
#[inline(always)] @@ -411,6 +390,7 @@ where crypto: C, handler: T, buffers: B, + kv: S, attr: &'a AttrDetails<'a>, notify: &'a dyn ChangeNotify, ) -> Self { @@ -419,17 +399,19 @@ where crypto, handler, buffers, + kv, attr, notify, } } } -impl BasicContext for ReadContextInstance<'_, C, T, B> +impl BasicContext for ReadContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn matter(&self) -> &Matter<'_> { self.exchange().matter() @@ -439,6 +421,10 @@ where &self.crypto } + fn kv(&self) -> impl KvBlobStoreAccess + '_ { + &self.kv + } + fn notify_attribute_changed( &self, endpoint_id: EndptId, @@ -449,11 +435,12 @@ where } } -impl HandlerContext for ReadContextInstance<'_, C, T, B> +impl HandlerContext for ReadContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn handler(&self) -> impl AsyncHandler + '_ { &self.handler @@ -464,11 +451,12 @@ where } } -impl Context for ReadContextInstance<'_, C, T, B> +impl Context for ReadContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn exchange(&self) -> &Exchange<'_> { self.exchange @@ -479,11 +467,12 @@ where } } -impl ReadContext for ReadContextInstance<'_, C, T, B> +impl ReadContext for ReadContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn attr(&self) -> &AttrDetails<'_> { self.attr @@ -491,38 +480,43 @@ where } /// A context implementation of the `WriteContext` trait -pub(crate) struct WriteContextInstance<'a, C, T, B> { +pub(crate) struct WriteContextInstance<'a, C, T, B, S> { exchange: &'a Exchange<'a>, crypto: C, handler: T, buffers: B, + kv: S, attr: &'a AttrDetails<'a>, data: &'a TLVElement<'a>, pub(crate) notify: &'a dyn ChangeNotify, } -impl<'a, C, T, B> WriteContextInstance<'a, C, T, B> +impl<'a, C, T, B, S> WriteContextInstance<'a, C, T, B, S> where C: Crypto, T: AsyncHandler, B: 
BufferAccess, + S: KvBlobStoreAccess, { /// Create a new instance. #[inline(always)] + #[allow(clippy::too_many_arguments)] pub(crate) const fn new( exchange: &'a Exchange<'a>, crypto: C, handler: T, buffers: B, + kv: S, attr: &'a AttrDetails<'a>, data: &'a TLVElement<'a>, notify: &'a dyn ChangeNotify, ) -> Self { Self { exchange, + crypto, handler, buffers, - crypto, + kv, attr, data, notify, @@ -530,11 +524,12 @@ where } } -impl BasicContext for WriteContextInstance<'_, C, T, B> +impl BasicContext for WriteContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn matter(&self) -> &Matter<'_> { self.exchange().matter() @@ -544,6 +539,10 @@ where &self.crypto } + fn kv(&self) -> impl KvBlobStoreAccess + '_ { + &self.kv + } + fn notify_attribute_changed( &self, endpoint_id: EndptId, @@ -554,11 +553,12 @@ where } } -impl HandlerContext for WriteContextInstance<'_, C, T, B> +impl HandlerContext for WriteContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn handler(&self) -> impl AsyncHandler + '_ { &self.handler @@ -569,11 +569,12 @@ where } } -impl Context for WriteContextInstance<'_, C, T, B> +impl Context for WriteContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn exchange(&self) -> &Exchange<'_> { self.exchange @@ -584,11 +585,12 @@ where } } -impl WriteContext for WriteContextInstance<'_, C, T, B> +impl WriteContext for WriteContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn attr(&self) -> &AttrDetails<'_> { self.attr @@ -600,38 +602,43 @@ where } /// A concrete implementation of the `InvokeContext` trait -pub(crate) struct InvokeContextInstance<'a, C, T, B> { +pub(crate) struct InvokeContextInstance<'a, C, T, B, S> { exchange: &'a Exchange<'a>, crypto: C, handler: T, buffers: B, + kv: S, cmd: &'a CmdDetails<'a>, data: &'a 
TLVElement<'a>, notify: &'a dyn ChangeNotify, } -impl<'a, C, T, B> InvokeContextInstance<'a, C, T, B> +impl<'a, C, T, B, S> InvokeContextInstance<'a, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { /// Construct a new instance. #[inline(always)] + #[allow(clippy::too_many_arguments)] pub(crate) const fn new( exchange: &'a Exchange<'a>, crypto: C, handler: T, buffers: B, + kv: S, cmd: &'a CmdDetails<'a>, data: &'a TLVElement<'a>, notify: &'a dyn ChangeNotify, ) -> Self { Self { exchange, + crypto, handler, buffers, - crypto, + kv, cmd, data, notify, @@ -639,11 +646,12 @@ where } } -impl BasicContext for InvokeContextInstance<'_, C, T, B> +impl BasicContext for InvokeContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn matter(&self) -> &Matter<'_> { self.exchange().matter() @@ -653,6 +661,10 @@ where &self.crypto } + fn kv(&self) -> impl KvBlobStoreAccess + '_ { + &self.kv + } + fn notify_attribute_changed( &self, endpoint_id: EndptId, @@ -663,11 +675,12 @@ where } } -impl HandlerContext for InvokeContextInstance<'_, C, T, B> +impl HandlerContext for InvokeContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn handler(&self) -> impl AsyncHandler + '_ { &self.handler @@ -678,11 +691,12 @@ where } } -impl Context for InvokeContextInstance<'_, C, T, B> +impl Context for InvokeContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn exchange(&self) -> &Exchange<'_> { self.exchange @@ -693,11 +707,12 @@ where } } -impl InvokeContext for InvokeContextInstance<'_, C, T, B> +impl InvokeContext for InvokeContextInstance<'_, C, T, B, S> where C: Crypto, T: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { fn cmd(&self) -> &CmdDetails<'_> { self.cmd diff --git a/rs-matter/src/dm/types/reply.rs b/rs-matter/src/dm/types/reply.rs index 727a47e63..b5c50e317 100644 
--- a/rs-matter/src/dm/types/reply.rs +++ b/rs-matter/src/dm/types/reply.rs @@ -24,6 +24,7 @@ use crate::im::{ AttrDataTag, AttrPath, AttrResp, AttrRespTag, AttrStatus, CmdDataTag, CmdPath, CmdResp, CmdRespTag, CmdStatus, EventData, EventFilter, EventPath, EventRespTag, IMStatusCode, }; +use crate::persist::KvBlobStoreAccess; use crate::tlv::{TLVArray, TLVElement, TLVTag, TLVWrite, TagType, ToTLV}; use crate::transport::exchange::Exchange; use crate::utils::storage::pooled::BufferAccess; @@ -71,25 +72,34 @@ pub trait InvokeReply { fn with_command(self, cmd: u32) -> Result; } -pub struct HandlerInvoker<'a, 'b, C, D, B> { +pub struct HandlerInvoker<'a, 'b, C, D, B, S> { exchange: &'b mut Exchange<'a>, crypto: C, handler: D, buffers: B, + kv: S, } -impl<'a, 'b, C, D, B> HandlerInvoker<'a, 'b, C, D, B> +impl<'a, 'b, C, D, B, S> HandlerInvoker<'a, 'b, C, D, B, S> where C: Crypto, D: AsyncHandler, B: BufferAccess, + S: KvBlobStoreAccess, { - pub const fn new(exchange: &'b mut Exchange<'a>, crypto: C, handler: D, buffers: B) -> Self { + pub const fn new( + exchange: &'b mut Exchange<'a>, + crypto: C, + handler: D, + buffers: B, + kv: S, + ) -> Self { Self { exchange, crypto, handler, buffers, + kv, } } @@ -164,6 +174,7 @@ where &self.crypto, &self.handler, &self.buffers, + &self.kv, attr, notify, ), @@ -237,6 +248,7 @@ where &self.crypto, &self.handler, &self.buffers, + &self.kv, attr, data, notify, @@ -317,6 +329,7 @@ where &self.crypto, &self.handler, &self.buffers, + &self.kv, cmd, data, notify, diff --git a/rs-matter/src/fabric.rs b/rs-matter/src/fabric.rs index 78bcdd253..0db959918 100644 --- a/rs-matter/src/fabric.rs +++ b/rs-matter/src/fabric.rs @@ -31,9 +31,10 @@ use crate::crypto::{ use crate::dm::Privilege; use crate::error::{Error, ErrorCode}; use crate::group_keys::{GroupKeySet, KeySet}; -use crate::tlv::{FromTLV, TLVElement, TLVTag, TLVWrite, TagType, ToTLV}; +use crate::persist::{KvBlobStore, KvBlobStoreAccess, Persist, FABRIC_KEYS_START}; +use 
crate::tlv::{FromTLV, TLVElement, ToTLV}; use crate::utils::init::{init, Init, InitMaybeUninit, IntoFallibleInit}; -use crate::utils::storage::{Vec, WriteBuf}; +use crate::utils::storage::Vec; use crate::MatterMdnsService; const COMPRESSED_FABRIC_ID_LEN: usize = 8; @@ -132,7 +133,7 @@ pub struct GroupKeyMapping { #[derive(Debug, FromTLV, ToTLV)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -struct Groups { +pub struct Groups { /// Group key sets (excluding IPK which is stored in `ipk`) key_sets: Vec, /// Groups keyset mapping @@ -149,6 +150,153 @@ impl Groups { endpoint_mapping <- Vec::init(), }) } + + /// Return an iterator over the group key sets of the fabric + pub fn key_set_iter(&self) -> impl Iterator { + self.key_sets.iter() + } + + /// Find a group key set by ID + pub fn key_set_get(&self, id: u16) -> Option<&GroupKeySet> { + self.key_sets.iter().find(|e| e.group_key_set_id == id) + } + + /// Add or update a group key set + pub fn key_set_add(&mut self, entry: GroupKeySet) -> Result<(), Error> { + if let Some(existing) = self + .key_sets + .iter_mut() + .find(|e| e.group_key_set_id == entry.group_key_set_id) + { + *existing = entry; + } else { + self.key_sets + .push(entry) + .map_err(|_| ErrorCode::ResourceExhausted)?; + } + Ok(()) + } + + /// Remove a group key set by ID. Returns true if found and removed. 
+ pub fn key_set_remove(&mut self, id: u16) -> Result<(), Error> { + let before = self.key_sets.len(); + self.key_sets.retain(|e| e.group_key_set_id != id); + let removed = self.key_sets.len() < before; + + self.key_map_remove_by_key_set(id); + + // Check if element was actually removed + if removed { + Ok(()) + } else { + Err(Error::new(ErrorCode::NotFound)) + } + } + + pub fn key_map_add(&mut self, entry: GroupKeyMapping) -> Result<(), Error> { + self.key_map.push(entry).map_err(|_| ErrorCode::Failure)?; + + Ok(()) + } + + /// Return an iterator over the group key map entries of the fabric + pub fn key_map_iter(&self) -> impl Iterator { + self.key_map.iter() + } + + /// Replace all group key map entries + pub fn key_map_replace( + &mut self, + entries: impl Iterator, + ) -> Result<(), Error> { + self.key_map.clear(); + for entry in entries { + self.key_map + .push(entry) + .map_err(|_| ErrorCode::ResourceExhausted)?; + } + Ok(()) + } + + /// Remove group key map entries that reference a specific key set ID + pub fn key_map_remove_by_key_set(&mut self, key_set_id: u16) { + self.key_map.retain(|e| e.group_key_set_id != key_set_id); + } + + /// Return an iterator over the group table entries + pub fn iter(&self) -> impl Iterator { + self.endpoint_mapping.iter() + } + + /// Look up a group by ID + pub fn get(&self, group_id: u16) -> Option<&GroupEndpointMapping> { + self.endpoint_mapping + .iter() + .find(|e| e.group_id == group_id) + } + + /// Add an endpoint to a group. + /// Returns true if the endpoint was already a member (name still updated per spec). 
+ pub fn add( + &mut self, + endpoint_id: u16, + group_id: u16, + group_name: &str, + ) -> Result { + let entry = if let Some(entry) = self + .endpoint_mapping + .iter_mut() + .find(|e| e.group_id == group_id) + { + entry + } else { + self.endpoint_mapping + .push(GroupEndpointMapping { + group_id, + endpoints: Vec::new(), + group_name: unwrap!(String::from_str(group_name)), + }) + .map_err(|_| ErrorCode::ResourceExhausted)?; + unwrap!(self.endpoint_mapping.last_mut()) + }; + + // Update group name + entry.group_name.clear(); + unwrap!(entry.group_name.push_str(group_name)); + + if entry.endpoints.contains(&endpoint_id) { + return Ok(true); + } + + entry + .endpoints + .push(endpoint_id) + .map_err(|_| ErrorCode::ResourceExhausted)?; + + Ok(false) + } + + /// Remove an endpoint from a group, or from all groups if `group_id` is `None`. + /// Returns true if the endpoint was removed from at least one group. + pub fn remove(&mut self, endpoint_id: u16, group_id: Option) -> bool { + let mut removed = false; + + for entry in self.endpoint_mapping.iter_mut() { + if group_id.is_some_and(|id| id != entry.group_id) { + continue; + } + let before = entry.endpoints.len(); + entry.endpoints.retain(|&ep| ep != endpoint_id); + if entry.endpoints.len() < before { + removed = true; + } + } + + // Remove entries with no endpoints left + self.endpoint_mapping.retain(|e| !e.endpoints.is_empty()); + + removed + } } /// Fabric type @@ -382,6 +530,16 @@ impl Fabric { &self.ipk } + /// Return the fabric's groups + pub fn groups(&self) -> &Groups { + &self.groups + } + + /// Return a mutable reference to the fabric's groups + pub fn groups_mut(&mut self) -> &mut Groups { + &mut self.groups + } + /// Return an iterator over the ACL entries of the fabric pub fn acl_iter(&self) -> impl Iterator { self.acl.iter() @@ -390,7 +548,7 @@ impl Fabric { /// Add a new ACL entry to the fabric. /// /// Return the index of the added entry. 
- fn acl_add(&mut self, mut entry: AclEntry) -> Result { + pub fn acl_add(&mut self, mut entry: AclEntry) -> Result { if entry.auth_mode() == AuthMode::Pase { // Reserved for future use Err(ErrorCode::ConstraintError)?; @@ -409,7 +567,7 @@ impl Fabric { /// Add a new ACL entry to the fabric using the supplied initializer. /// /// Return the index of the added entry. - fn acl_add_init(&mut self, init: I) -> Result + pub fn acl_add_init(&mut self, init: I) -> Result where I: Init, { @@ -431,7 +589,7 @@ impl Fabric { } /// Update an existing ACL entry in the fabric - fn acl_update(&mut self, idx: usize, mut entry: AclEntry) -> Result<(), Error> { + pub fn acl_update(&mut self, idx: usize, mut entry: AclEntry) -> Result<(), Error> { if self.acl.len() <= idx { return Err(ErrorCode::NotFound.into()); } @@ -445,7 +603,7 @@ impl Fabric { } /// Update an existing ACL entry in the fabric using the supplied initializer - fn acl_update_init(&mut self, idx: usize, init: I) -> Result<(), Error> + pub fn acl_update_init(&mut self, idx: usize, init: I) -> Result<(), Error> where I: Init, { @@ -466,7 +624,7 @@ impl Fabric { } /// Remove an ACL entry from the fabric - fn acl_remove(&mut self, idx: usize) -> Result<(), Error> { + pub fn acl_remove(&mut self, idx: usize) -> Result<(), Error> { if self.acl.len() <= idx { return Err(ErrorCode::NotFound.into()); } @@ -501,166 +659,6 @@ impl Fabric { false } - /// Return an iterator over the group key sets of the fabric - pub fn group_key_set_iter(&self) -> impl Iterator { - self.groups.key_sets.iter() - } - - /// Find a group key set by ID - pub fn group_key_set_get(&self, id: u16) -> Option<&GroupKeySet> { - self.groups - .key_sets - .iter() - .find(|e| e.group_key_set_id == id) - } - - /// Add or update a group key set - fn group_key_set_add(&mut self, entry: GroupKeySet) -> Result<(), Error> { - if let Some(existing) = self - .groups - .key_sets - .iter_mut() - .find(|e| e.group_key_set_id == entry.group_key_set_id) - { - *existing = 
entry; - } else { - self.groups - .key_sets - .push(entry) - .map_err(|_| ErrorCode::ResourceExhausted)?; - } - Ok(()) - } - - /// Remove a group key set by ID. Returns true if found and removed. - fn group_key_set_remove(&mut self, id: u16) -> Result<(), Error> { - let before = self.groups.key_sets.len(); - self.groups.key_sets.retain(|e| e.group_key_set_id != id); - - // Check if element was actually removed - // TODO: Should this also remove group entries referencing this key set? - match self.groups.key_sets.len() < before { - true => Ok(()), - false => Err(Error::new(ErrorCode::NotFound)), - } - } - - fn group_key_map_add(&mut self, entry: GroupKeyMapping) -> Result<(), Error> { - self.groups - .key_map - .push(entry) - .map_err(|_| ErrorCode::Failure)?; - - Ok(()) - } - - /// Return an iterator over the group key map entries of the fabric - pub fn group_key_map_iter(&self) -> impl Iterator { - self.groups.key_map.iter() - } - - /// Replace all group key map entries - fn group_key_map_replace( - &mut self, - entries: impl Iterator, - ) -> Result<(), Error> { - self.groups.key_map.clear(); - for entry in entries { - self.groups - .key_map - .push(entry) - .map_err(|_| ErrorCode::ResourceExhausted)?; - } - Ok(()) - } - - /// Remove group key map entries that reference a specific key set ID - fn group_key_map_remove_by_key_set(&mut self, key_set_id: u16) { - self.groups - .key_map - .retain(|e| e.group_key_set_id != key_set_id); - } - - /// Return an iterator over the group table entries - pub fn group_iter(&self) -> impl Iterator { - self.groups.endpoint_mapping.iter() - } - - /// Look up a group by ID - pub fn group_get(&self, group_id: u16) -> Option<&GroupEndpointMapping> { - self.groups - .endpoint_mapping - .iter() - .find(|e| e.group_id == group_id) - } - - /// Add an endpoint to a group. - /// Returns true if the endpoint was already a member (name still updated per spec). 
- fn group_add( - &mut self, - endpoint_id: u16, - group_id: u16, - group_name: &str, - ) -> Result { - let entry = if let Some(entry) = self - .groups - .endpoint_mapping - .iter_mut() - .find(|e| e.group_id == group_id) - { - entry - } else { - self.groups - .endpoint_mapping - .push(GroupEndpointMapping { - group_id, - endpoints: Vec::new(), - group_name: unwrap!(String::from_str(group_name)), - }) - .map_err(|_| ErrorCode::ResourceExhausted)?; - unwrap!(self.groups.endpoint_mapping.last_mut()) - }; - - // Update group name - entry.group_name.clear(); - unwrap!(entry.group_name.push_str(group_name)); - - if entry.endpoints.contains(&endpoint_id) { - return Ok(true); - } - - entry - .endpoints - .push(endpoint_id) - .map_err(|_| ErrorCode::ResourceExhausted)?; - - Ok(false) - } - - /// Remove an endpoint from a group, or from all groups if `group_id` is `None`. - /// Returns true if the endpoint was removed from at least one group. - fn group_remove(&mut self, endpoint_id: u16, group_id: Option) -> bool { - let mut removed = false; - - for entry in self.groups.endpoint_mapping.iter_mut() { - if group_id.is_some_and(|id| id != entry.group_id) { - continue; - } - let before = entry.endpoints.len(); - entry.endpoints.retain(|&ep| ep != endpoint_id); - if entry.endpoints.len() < before { - removed = true; - } - } - - // Remove entries with no endpoints left - self.groups - .endpoint_mapping - .retain(|e| !e.endpoints.is_empty()); - - removed - } - /// Compute the compressed fabric ID pub(crate) fn compute_compressed_fabric_id( crypto: C, @@ -709,7 +707,6 @@ cfg_if! 
{ /// All fabrics pub struct Fabrics { fabrics: Vec, - changed: bool, } impl Default for Fabrics { @@ -724,7 +721,6 @@ impl Fabrics { pub const fn new() -> Self { Self { fabrics: Vec::new(), - changed: false, } } @@ -732,72 +728,65 @@ impl Fabrics { pub fn init() -> impl Init { init!(Self { fabrics <- Vec::init(), - changed: false, }) } - /// Removes all fabrics - /// - /// # Arguments - /// - `flag_changed`: Whether to mark the fabrics as changed - pub fn reset(&mut self, flag_changed: bool) { + /// Remove all fabrics + pub fn reset(&mut self) { self.fabrics.clear(); - self.changed = flag_changed; } - /// Load the fabrics from the provided buffer as TLV data. + /// Remove all fabrics from the provided BLOB store as well as from memory. /// /// # Arguments - /// - `data`: The TLV data to load the fabrics from - /// - `mdns_notif`: A callback function to notify about mDNS changes - pub fn load(&mut self, data: &[u8], mdns_notif: &mut dyn FnMut()) -> Result<(), Error> { - self.fabrics.clear(); - - mdns_notif(); - - for entry in TLVElement::new(data).array()?.iter() { - let entry = entry?; + /// - `store`: the BLOB store to remove the fabrics from + /// - `buf`: a temporary buffer to use for removing the fabrics + pub async fn reset_persist( + &mut self, + mut store: S, + buf: &mut [u8], + ) -> Result<(), Error> { + self.reset(); - self.fabrics.push_init(Fabric::init_from_tlv(entry), || { - ErrorCode::ResourceExhausted.into() - })?; + for idx in 1..=255u8 { + store.remove(FABRIC_KEYS_START + idx as u16, buf)?; } - mdns_notif(); - self.changed = false; + info!("Removed all fabrics from storage"); Ok(()) } - /// Store the fabrics into the provided buffer as TLV data. + /// Load all fabrics from the provided BLOB store /// /// # Arguments - /// - `buf`: The byte slice to store the state into - /// - /// Returns the number of bytes written into the buffer. 
- pub fn store(&mut self, buf: &mut [u8]) -> Result { - let mut wb = WriteBuf::new(buf); - - wb.start_array(&TLVTag::Anonymous)?; - - for fabric in self.iter() { - fabric - .to_tlv(&TagType::Anonymous, &mut wb) - .map_err(|_| ErrorCode::NoSpace)?; + /// - `store`: the BLOB store to load the fabrics from + /// - `buf`: a temporary buffer to use for loading the fabrics + pub async fn load_persist( + &mut self, + mut store: S, + buf: &mut [u8], + ) -> Result<(), Error> { + self.reset(); + + for idx in 1..=255u8 { + if let Some(data) = store.load(FABRIC_KEYS_START + idx as u16, buf)? { + self.fabrics + .push_init(Fabric::init_from_tlv(TLVElement::new(data)), || { + ErrorCode::ResourceExhausted.into() + })?; + + let fabric = unwrap!(self.fabrics.last()); + + info!( + "Loaded fabric {} with ID {:x} from storage", + fabric.fab_idx(), + fabric.compressed_fabric_id() + ); + } } - wb.end_container()?; - - self.changed = false; - - let len = wb.get_tail(); - - Ok(len) - } - - /// Check if the fabrics have changed since the last store operation - pub fn is_changed(&self) -> bool { - self.changed + Ok(()) } /// Add a new fabric to the fabrics with the provided data and immediately updates it with the provided post-init updater. 
@@ -836,7 +825,6 @@ impl Fabrics { )?; let fabric = unwrap!(self.fabrics.last_mut()); - self.changed = true; Ok(fabric) } @@ -900,12 +888,10 @@ impl Fabrics { crypto, root_ca, noc, icac, secret_key, None, None, None, mdns_notif, )?; - self.changed = true; - Ok(fabric) } - pub fn update_label(&mut self, fab_idx: NonZeroU8, label: &str) -> Result<(), Error> { + pub fn update_label(&mut self, fab_idx: NonZeroU8, label: &str) -> Result<&mut Fabric, Error> { if self.iter().any(|fabric| { fabric.fab_idx != fab_idx && !fabric.label.is_empty() && fabric.label == label }) { @@ -919,9 +905,7 @@ impl Fabrics { .push_str(label) .map_err(|_| ErrorCode::ConstraintError)?; - self.changed = true; - - Ok(()) + Ok(fabric) } /// Remove a fabric from the fabrics @@ -937,7 +921,6 @@ impl Fabrics { self.fabrics.retain(|fabric| fabric.fab_idx != fab_idx); mdns_notif(); - self.changed = true; Ok(()) } @@ -971,6 +954,20 @@ impl Fabrics { self.fabrics.iter() } + /// Get a fabric by its local index + /// + /// Returns an error if the fabric is not found + pub fn fabric(&self, fab_idx: NonZeroU8) -> Result<&Fabric, Error> { + self.get(fab_idx).ok_or(ErrorCode::NotFound.into()) + } + + /// Get a mutable fabric reference by its local index + /// + /// Returns an error if the fabric is not found + pub fn fabric_mut(&mut self, fab_idx: NonZeroU8) -> Result<&mut Fabric, Error> { + self.get_mut(fab_idx).ok_or(ErrorCode::NotFound.into()) + } + /// Check if the given access request should be allowed, based on all operational fabrics /// and their ACLs pub fn allow(&self, req: &AccessReq) -> bool { @@ -1009,195 +1006,34 @@ impl Fabrics { fabric.allow(req) } +} - /// Add a new ACL entry to the fabric with the provided local index - /// - /// Return the index of the added entry. - pub fn acl_add(&mut self, fab_idx: NonZeroU8, entry: AclEntry) -> Result { - let index = self - .get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? 
- .acl_add(entry)?; - self.changed = true; - - Ok(index) - } - - /// Add a new ACL entry to the fabric with the provided local index and initializer - /// - /// Return the index of the added entry. - pub fn acl_add_init(&mut self, fab_idx: NonZeroU8, init: I) -> Result - where - I: Init, - { - let index = self - .get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .acl_add_init(init)?; - self.changed = true; - - Ok(index) - } - - /// Update an existing ACL entry in the fabric with the provided local index - pub fn acl_update( - &mut self, - fab_idx: NonZeroU8, - idx: usize, - entry: AclEntry, - ) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .acl_update(idx, entry)?; - self.changed = true; - - Ok(()) - } - - /// Update an existing ACL entry in the fabric with the provided local index and initializer - pub fn acl_update_init( - &mut self, - fab_idx: NonZeroU8, - idx: usize, - init: I, - ) -> Result<(), Error> - where - I: Init, - { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .acl_update_init(idx, init)?; - self.changed = true; - - Ok(()) - } - - /// Remove an ACL entry from the fabric with the provided local index - pub fn acl_remove(&mut self, fab_idx: NonZeroU8, idx: usize) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .acl_remove(idx)?; - self.changed = true; - - Ok(()) - } - - /// Remove all ACL entries from the fabric with the provided local index - pub fn acl_remove_all(&mut self, fab_idx: NonZeroU8) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .acl_remove_all(); - self.changed = true; - - Ok(()) - } - - /// Add or update a group key set for the fabric with the provided local index - pub fn group_key_set_add( - &mut self, - fab_idx: NonZeroU8, - entry: GroupKeySet, - ) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? 
- .group_key_set_add(entry)?; - self.changed = true; - Ok(()) - } - - /// Remove a group key set from the fabric with the provided local index. - /// Also removes all group key map entries referencing this key set. - /// Returns Ok(()) if the key set was found and removed. - pub fn group_key_set_remove(&mut self, fab_idx: NonZeroU8, id: u16) -> Result<(), Error> { - let fabric = self.get_mut(fab_idx).ok_or(ErrorCode::NotFound)?; - fabric.group_key_set_remove(id)?; - fabric.group_key_map_remove_by_key_set(id); - self.changed = true; - Ok(()) - } - - /// Replace all group key map entries for the fabric with the provided local index - pub fn group_key_map_replace( - &mut self, - fab_idx: NonZeroU8, - entries: impl Iterator, - ) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .group_key_map_replace(entries)?; - self.changed = true; - Ok(()) - } - - /// Replace all group key map entries for the fabric with the provided local index - pub fn group_key_map_add( - &mut self, - fab_idx: NonZeroU8, - entry: GroupKeyMapping, - ) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .group_key_map_add(entry)?; - self.changed = true; - Ok(()) - } - - #[inline(always)] - pub fn max_group_keys_per_fabric(&self) -> u16 { - // Group key sets + IPK - MAX_GROUP_KEYS_PER_FABRIC as u16 - } +/// A utility for persisting a fabric in a `KvBlobStore` instance. +pub struct FabricPersist(Persist); - #[inline(always)] - pub fn max_groups_per_fabric(&self) -> u16 { - MAX_GROUPS_PER_FABRIC as u16 +impl FabricPersist +where + S: KvBlobStoreAccess, +{ + /// Create a new `FabricPersist` with the given key-value store instance. + pub const fn new(kvb: S) -> Self { + Self(Persist::new(kvb)) } - /// Add an endpoint to a group for the given fabric. - /// Returns true if the endpoint was already a member (name updated). 
- pub fn group_add( - &mut self, - endpoint_id: u16, - group_id: u16, - group_name: &str, - fab_idx: NonZeroU8, - ) -> Result { - let result = self - .get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .group_add(endpoint_id, group_id, group_name)?; - self.changed = true; - Ok(result) + /// Save the provided fabric in the persistent storage. + pub fn store(&mut self, fabric: &Fabric) -> Result<(), Error> { + self.0 + .store_tlv(FABRIC_KEYS_START + fabric.fab_idx().get() as u16, fabric) } - /// Remove an endpoint from a group for the given fabric. - /// Returns true if the endpoint was a member and was removed. - pub fn group_remove( - &mut self, - endpoint_id: u16, - group_id: u16, - fab_idx: NonZeroU8, - ) -> Result { - let removed = self - .get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? - .group_remove(endpoint_id, Some(group_id)); - if removed { - self.changed = true; - } - Ok(removed) + /// Remove the fabric with the given index from the persistent storage. + pub fn remove(&mut self, fab_idx: NonZeroU8) -> Result<(), Error> { + self.0.remove(FABRIC_KEYS_START + fab_idx.get() as u16) } - /// Remove all group memberships for an endpoint on the given fabric. - pub fn group_remove_all_for_endpoint( - &mut self, - endpoint_id: u16, - fab_idx: NonZeroU8, - ) -> Result<(), Error> { - self.get_mut(fab_idx) - .ok_or(ErrorCode::NotFound)? 
- .group_remove(endpoint_id, None); - self.changed = true; - Ok(()) + /// Call at the end when finished with everything else + /// No-op for now + pub fn run(self) -> Result<(), Error> { + self.0.run() } } diff --git a/rs-matter/src/failsafe.rs b/rs-matter/src/failsafe.rs index bd5dcbeb4..e5d076fc5 100644 --- a/rs-matter/src/failsafe.rs +++ b/rs-matter/src/failsafe.rs @@ -23,10 +23,9 @@ use crate::crypto::{ CanonAeadKeyRef, CanonPkcSecretKey, CanonPkcSecretKeyRef, Crypto, SecretKey, PKC_SECRET_KEY_ZEROED, }; -use crate::dm::BasicContext; use crate::error::{Error, ErrorCode}; -use crate::fabric::Fabrics; -use crate::im::IMStatusCode; +use crate::fabric::{Fabric, Fabrics}; +use crate::im::{AttrId, ClusterId, EndptId, IMStatusCode}; use crate::sc::pase::Pase; use crate::tlv::TLVElement; use crate::transport::session::SessionMode; @@ -115,7 +114,8 @@ impl FailSafe { breadcrumb: u64, session_mode: &SessionMode, pase: &mut Pase, - ctx: impl BasicContext, + notify_mdns: impl FnMut(), + notify_change: impl FnMut(EndptId, ClusterId, AttrId), ) -> Result<(), Error> { self.update_state_timeout(); @@ -126,7 +126,8 @@ impl FailSafe { } // Cannot arm via CASE while there's an active window - if pase.comm_window(&ctx)?.is_some() && matches!(session_mode, SessionMode::Case { .. }) + if pase.comm_window(notify_mdns, notify_change)?.is_some() + && matches!(session_mode, SessionMode::Case { .. 
}) { return Err(ErrorCode::Busy)?; } @@ -259,16 +260,16 @@ impl FailSafe { } #[allow(clippy::too_many_arguments)] - pub fn update_noc( + pub fn update_noc<'a, C: Crypto>( &mut self, crypto: C, - fabrics: &mut Fabrics, + fabrics: &'a mut Fabrics, session_mode: &SessionMode, icac: Option<&[u8]>, noc: &[u8], buf: &mut [u8], mdns_notif: &mut dyn FnMut(), - ) -> Result<(), Error> { + ) -> Result<&'a mut Fabric, Error> { self.update_state_timeout(); let fab_idx = Self::get_case_fab_idx(session_mode)?; @@ -289,7 +290,7 @@ impl FailSafe { buf, )?; - fabrics.update( + let fabric = fabrics.update( &crypto, fab_idx, self.secret_key.reference(), @@ -301,14 +302,14 @@ impl FailSafe { self.add_flags(NocFlags::UPDATE_NOC_RECVD); - Ok(()) + Ok(fabric) } #[allow(clippy::too_many_arguments)] - pub fn add_noc( + pub fn add_noc<'a, C: Crypto>( &mut self, crypto: C, - fabrics: &mut Fabrics, + fabrics: &'a mut Fabrics, session_mode: &SessionMode, vendor_id: u16, icac: Option<&[u8]>, @@ -317,7 +318,7 @@ impl FailSafe { case_admin_subject: u64, buf: &mut [u8], mdns_notif: &mut dyn FnMut(), - ) -> Result { + ) -> Result<&'a mut Fabric, Error> { self.update_state_timeout(); self.check_state( @@ -339,7 +340,7 @@ impl FailSafe { // TODO: Copy functionality from C++ FabricTable::FindExistingFabricByNocChaining // i.e. need to check to see if a fabric with these creds are already present - let fab_idx = fabrics + let fabric = fabrics .add( &crypto, self.secret_key.reference(), @@ -357,10 +358,12 @@ impl FailSafe { } else { e } - })? 
- .fab_idx(); + })?; - info!("Added operational fabric with local index {}", fab_idx); + info!( + "Added operational fabric with local index {}", + fabric.fab_idx() + ); let State::Armed(ctx) = &mut self.state else { // Impossible to be in any other state because otherwise @@ -368,10 +371,10 @@ impl FailSafe { unreachable!(); }; - ctx.fab_idx = fab_idx.get(); + ctx.fab_idx = fabric.fab_idx().get(); self.add_flags(NocFlags::ADD_NOC_RECVD); - Ok(fab_idx) + Ok(fabric) } pub fn breadcrumb(&mut self) -> u64 { diff --git a/rs-matter/src/lib.rs b/rs-matter/src/lib.rs index dae404aac..7b5f69d3c 100644 --- a/rs-matter/src/lib.rs +++ b/rs-matter/src/lib.rs @@ -34,15 +34,17 @@ use core::future::Future; use crate::crypto::Crypto; use crate::dm::clusters::basic_info::{BasicInfoConfig, BasicInfoSettings}; use crate::dm::clusters::dev_att::DeviceAttestation; -use crate::dm::{BasicContextInstance, ChangeNotify}; +use crate::dm::ChangeNotify; use crate::error::{Error, ErrorCode}; use crate::fabric::Fabrics; use crate::failsafe::FailSafe; +use crate::im::{AttrId, ClusterId, EndptId}; use crate::pairing::qr::{ no_optional_data, CommFlowType, NoOptionalData, Qr, QrPayload, QrTextType, }; use crate::pairing::DiscoveryCapabilities; -use crate::sc::pase::spake2p::Spake2pVerifierPassword; +use crate::persist::KvBlobStore; +use crate::sc::pase::spake2p::{Spake2pVerifierPassword, SPAKE2P_VERIFIER_SALT_ZEROED}; use crate::sc::pase::Pase; use crate::transport::network::{NetworkMulticast, NetworkReceive, NetworkSend}; use crate::transport::session::Sessions; @@ -57,6 +59,8 @@ use crate::utils::storage::WriteBuf; use crate::utils::sync::blocking::Mutex; use crate::utils::sync::Notification; +use rand_core::RngCore; + /// Re-export the `rs_matter_macros::import` proc-macro pub use rs_matter_macros::import; @@ -212,8 +216,6 @@ pub struct Matter<'a> { /// /// Public for unit tests pub transport: Transport, - /// A notification that the Matter state had changed in a way that might require 
persistence - state_changed: Notification, /// A notification that the Matter mDNS services might have changed mdns_changed: Notification, /// A notification that a session had been removed @@ -279,7 +281,6 @@ impl<'a> Matter<'a> { Self { state: Mutex::new(RefCell::new(MatterState::new(epoch))), transport: Transport::new(dev_det), - state_changed: Notification::new(), mdns_changed: Notification::new(), session_removed: Notification::new(), groups_modified: Notification::new(), @@ -337,7 +338,6 @@ impl<'a> Matter<'a> { Self { state <- Mutex::init(RefCell::init(MatterState::init(epoch))), transport <- Transport::init(dev_det), - state_changed: Notification::new(), mdns_changed: Notification::new(), session_removed: Notification::new(), groups_modified: Notification::new(), @@ -512,15 +512,27 @@ impl<'a> Matter<'a> { crypto: C, notify: &dyn ChangeNotify, ) -> Result<(), Error> { - let ctx = BasicContextInstance::new(self, crypto, notify); + let notify_mdns = || self.notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id); self.with_state(|state| { + let mut rand = crypto.rand()?; + + let mdns_id = rand.next_u64(); + + let mut salt = SPAKE2P_VERIFIER_SALT_ZEROED; + rand.fill_bytes(salt.access_mut()); + state.pase.open_basic_comm_window( - ctx, + mdns_id, + salt.reference(), self.dev_comm.password.reference(), self.dev_comm.discriminator, timeout_secs, None, + notify_mdns, + notify_change, ) }) } @@ -528,14 +540,12 @@ impl<'a> Matter<'a> { /// Close the basic commissioning window /// /// The method will return Ok(false) if there is no active PASE commissioning window to close. 
- pub fn close_comm_window( - &self, - crypto: C, - notify: &dyn ChangeNotify, - ) -> Result { - let ctx = BasicContextInstance::new(self, crypto, notify); + pub fn close_comm_window(&self, notify: &dyn ChangeNotify) -> Result { + let notify_mdns = || self.notify_mdns(); + let notify_change = + |endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id); - self.with_state(|state| state.pase.close_comm_window(ctx)) + self.with_state(|state| state.pase.close_comm_window(notify_mdns, notify_change)) } /// Create a new transport runner instance @@ -598,104 +608,65 @@ impl<'a> Matter<'a> { /// Reset the Matter persistable state by removing all fabrics and resetting basic info settings /// - /// # Arguments - /// - `flag_changed`: If true, notifies that fabrics and basic info settings have changed - pub fn reset_persist(&self, flag_changed: bool) { - self.with_state(|state| { - state.basic_info_settings.reset(flag_changed); - state.fabrics.reset(flag_changed); - }); - - self.notify_mdns(); + /// Arguments: + /// - `kv`: The key-value store to load the fabrics and basic info settings from + /// - `buf`: A buffer to use for loading the fabrics and basic info settings + pub async fn reset_persist( + &mut self, + mut kv: S, + buf: &mut [u8], + ) -> Result<(), Error> { + { + let state = self.state.get_mut(); + let mut state = state.borrow_mut(); - if flag_changed { - self.notify_persist(); + state.fabrics.reset_persist(&mut kv, buf).await?; + state + .basic_info_settings + .reset_persist(&mut kv, buf) + .await?; } - } - /// Notify that the ACLs, Fabrics or Basic Info _might_ have changed - /// This method is supposed to be called after processing SC and IM messages that might affect the ACLs, Fabrics or Basic Info. - /// - /// The default IM and SC handlers (`DataModel` and `SecureChannel`) do call this method after processing the messages. - /// - /// TODO: Fix the method name as it is not clear enough. 
Potentially revamp the whole persistence notification logic - pub fn notify_persist(&self) { - self.state_changed.notify(); - } - - /// Load fabrics from the given data - /// - /// Arguments: - /// - `data`: The data to load the fabrics from - pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { - self.with_state(|state| state.fabrics.load(data, &mut || self.notify_mdns())) - } - - /// Store fabrics into the given buffer - /// - /// Arguments: - /// - `buf`: The buffer to store the fabrics into - /// - /// Returns the number of bytes written into the buffer. - pub fn store_fabrics(&self, buf: &mut [u8]) -> Result { - self.with_state(|state| state.fabrics.store(buf)) - } + self.notify_mdns(); - /// Return true if the fabrics have changed since the last call to `store_fabrics` - pub fn fabrics_changed(&self) -> bool { - self.with_state(|state| state.fabrics.is_changed()) + Ok(()) } - /// Load basic info settings from the given data + /// Load fabrics from the given data /// /// Arguments: - /// - `data`: The data to load the basic info settings from - pub fn load_basic_info(&self, data: &[u8]) -> Result<(), Error> { - self.with_state(|state| state.basic_info_settings.load(data)) - } + /// - `kv`: The key-value store to load the fabrics and basic info settings from + /// - `buf`: A buffer to use for loading the fabrics and basic info settings + pub async fn load_persist( + &mut self, + mut kv: S, + buf: &mut [u8], + ) -> Result<(), Error> { + { + let state = self.state.get_mut(); + let mut state = state.borrow_mut(); - /// Store basic info settings into the given buffer - /// - /// Arguments: - /// - `buf`: The buffer to store the basic info settings into - /// - /// Returns the number of bytes written into the buffer. 
- pub fn store_basic_info(&self, buf: &mut [u8]) -> Result { - self.with_state(|state| state.basic_info_settings.store(buf)) - } + state.fabrics.load_persist(&mut kv, buf).await?; + state.basic_info_settings.load_persist(&mut kv, buf).await?; + } - /// Return true if the basic info settings have changed since the last call to `store_basic_info` - pub fn basic_info_changed(&self) -> bool { - self.with_state(|state| state.basic_info_settings.changed) - } + self.notify_mdns(); - /// A hook for user persistence code to wait for potential changes in ACLs, Fabrics or basic info. - /// - /// Once this future resolves, user code is supposed to inspect ACLs, Fabrics and basic info for changes, and - /// if there are changes, persist them. - /// - /// TODO: Fix the method name as it is not clear enough. Potentially revamp the whole persistence notification logic - pub fn wait_persist(&self) -> impl Future + '_ { - self.state_changed.wait() + Ok(()) } /// Invoke the given closure for each currently published Matter mDNS service. - pub fn mdns_services( - &self, - crypto: C, - notify: &dyn ChangeNotify, - mut f: F, - ) -> Result<(), Error> + pub fn mdns_services(&self, notify_change: C, mut f: F) -> Result<(), Error> where - C: Crypto, + C: FnMut(EndptId, ClusterId, AttrId), F: FnMut(MatterMdnsService) -> Result<(), Error>, { - let ctx = BasicContextInstance::new(self, crypto, notify); - debug!("=== Currently published mDNS services"); + let notify_mdns = || self.notify_mdns(); + self.with_state(|state| { - if let Some(comm_window) = state.pase.comm_window(ctx)? { + if let Some(comm_window) = state.pase.comm_window(notify_mdns, notify_change)? { // Do not remove this logging line or change its formatting. 
// C++ E2E tests rely on this log line to determine when the mDNS service is published debug!("mDNS service published: {:?}", comm_window.mdns_service()); diff --git a/rs-matter/src/persist.rs b/rs-matter/src/persist.rs index 78d6722f2..936c041e0 100644 --- a/rs-matter/src/persist.rs +++ b/rs-matter/src/persist.rs @@ -15,248 +15,286 @@ * limitations under the License. */ -//! This module provides a simple persistent storage manager (PSM) for `rs-matter`. +//! This module provides the key-value BLOB store traits used throughout `rs-matter` for persistence, as well as some implementations for those. + +use core::borrow::BorrowMut; + +use crate::error::Error; +use crate::tlv::{TLVTag, ToTLV}; +use crate::utils::cell::RefCell; +use crate::utils::storage::WriteBuf; +use crate::utils::sync::blocking::Mutex; #[cfg(feature = "std")] pub use fileio::*; -#[cfg(feature = "std")] -pub mod fileio { - use core::mem::MaybeUninit; +/// The first key available for the vendor-specific data. +pub const VENDOR_KEYS_START: u16 = 0x1000; - use std::fs; - use std::io::{Read, Write}; - use std::path::Path; +/// The key range reserved for fabrics (256 keys). +pub const FABRIC_KEYS_START: u16 = 0; - use embassy_futures::select::{select, select3}; +/// The key used for storing the basic info settings. +pub const BASIC_INFO_KEY: u16 = FABRIC_KEYS_START + 256; - use crate::dm::events::Events; - use crate::dm::networks::wireless::{Wifi, WirelessNetwork, WirelessNetworks}; - use crate::error::{Error, ErrorCode}; - use crate::tlv::{ - Octets, TLVArray, TLVContainerIter, TLVElement, TLVTag, TLVValueType, TLVWrite, - }; - use crate::utils::init::{init, Init}; - use crate::utils::storage::WriteBuf; - use crate::Matter; +/// The key used for storing the events epoch number. +pub const EVENT_EPOCH_KEY: u16 = BASIC_INFO_KEY + 1; - /// A constant representing the absence of wireless networks. 
- pub const NO_NETWORKS: Option<&'static WirelessNetworks<0, Wifi>> = None; +/// The key used for storing the wireless networks state. +pub const NETWORKS_KEY: u16 = BASIC_INFO_KEY + 2; - /// A constant representing the absence of events. - pub const NO_EVENTS: Option<&'static Events<0>> = None; +/// A trait representing a key-value BLOB storage. +/// +/// NOTE: For now, the trait is deliberately modeled as non-async, so that it can be used from +/// regular `Handler` non-async instances so as to avoid code bloat due to too much async handlers. +/// +/// However, this might change in future once/if rustc starts to optimize the generated async code a bit better. +pub trait KvBlobStore { + /// Load a BLOB with the specified key from the storage. + /// + /// # Arguments + /// - `key` - the key of the BLOB + /// - `buf` - a buffer that the `KvBlobStore` implementation might use for its own purposes + /// + /// # Returns + /// - `Ok(Some(&[u8]))` if the BLOB was successfully loaded, + /// - `Ok(None)` if the BLOB with the specified key does not exist, + /// - `Err` if an error occurred during loading. + fn load<'a>(&mut self, key: u16, buf: &'a mut [u8]) -> Result, Error>; - /// A simple persistent storage manager (PSM) for `rs-matter`. + /// Store a BLOB with the specified key in the storage. + /// + /// # Arguments + /// - `key` - the key of the BLOB + /// - `data` - the data to store + /// - `buf` - a buffer that the `KvBlobStore` implementation might use for its own purposes /// - /// This storage saves everything (fabrics, basic info settings and wireless networks (if any)) - /// as a single file, which is compatible with the `chip-tool` YAML tests which - at least in V1.3.0.0 - - /// do expect a single file for all persistent data. + /// # Returns + /// - `Ok(())` if the BLOB was successfully stored, + /// - `Err` if an error occurred during storing. 
+ fn store(&mut self, key: u16, data: &[u8], buf: &mut [u8]) -> Result<(), Error>; + + /// Remove a BLOB with the specified key from the storage. /// - /// Moreover, this storage always persists the whole state, regardless what had changed, which - /// requires a large memory buffer, which can keep the TLV data of all fabrics, basic info settings and wireless networks. + /// # Arguments + /// - `key` - the key of the BLOB + /// - `buf` - a buffer that the `KvBlobStore` implementation might use for its own purposes /// - /// NOTE: Production applications might need a more sophisticated persistent storage where e.g. - /// each fabric is stored as a separate item. - pub struct Psm { - buf: MaybeUninit<[u8; N]>, + /// # Returns + /// - `Ok(())` if the BLOB was successfully removed or did not exist + /// - `Err` if an error occurred during removing. + fn remove(&mut self, key: u16, buf: &mut [u8]) -> Result<(), Error>; +} + +impl KvBlobStore for &mut T +where + T: KvBlobStore, +{ + fn load<'a>(&mut self, key: u16, buf: &'a mut [u8]) -> Result, Error> { + T::load(self, key, buf) } - impl Default for Psm { - fn default() -> Self { - Self::new() - } + fn store(&mut self, key: u16, data: &[u8], buf: &mut [u8]) -> Result<(), Error> { + T::store(self, key, data, buf) } - impl Psm { - /// Create a new `Psm` instance. - #[inline(always)] - pub const fn new() -> Self { - Self { - buf: MaybeUninit::uninit(), - } - } + fn remove(&mut self, key: u16, buf: &mut [u8]) -> Result<(), Error> { + T::remove(self, key, buf) + } +} - /// Return an in-place initializer for `Psm`. - pub fn init() -> impl Init { - init!(Self { - buf <- crate::utils::init::zeroed(), - }) - } +/// A noop implementation of the `KvBlobStore` trait. +pub struct DummyKvBlobStore; - /// Load the persistent state from the given file path into the provided `Matter` instance - /// - /// Arguments: - /// - `path`: The file path from where to load the persistent state. 
- /// - `matter`: The `Matter` instance to load the state into (for fabrics and basic info settings). - /// - `networks`: An optional reference to `WirelessNetworks` to load the wireless networks state into (if provided). - /// - `events`: An optional reference to `Events` to load the events state into (if provided). - pub fn load( - &mut self, - path: P, - matter: &Matter, - networks: Option<&WirelessNetworks>, - events: Option<&Events>, - ) -> Result<(), Error> - where - P: AsRef, - T: WirelessNetwork, - { - let buf = unsafe { self.buf.assume_init_mut() }; - - let Some(data) = Self::load_storage(path.as_ref(), buf)? else { - return Ok(()); - }; - - let root = TLVElement::new(data); - - if root.control()?.value_type == TLVValueType::Array { - // Legacy format: anonymous array with positional octet-strings - let mut items: TLVContainerIter<'_, Octets<'_>> = TLVArray::new(root)?.iter(); - - matter.load_fabrics(items.next().ok_or(ErrorCode::Invalid)??.0)?; - matter.load_basic_info(items.next().ok_or(ErrorCode::Invalid)??.0)?; - - if let Some(networks) = networks { - networks.load(items.next().ok_or(ErrorCode::Invalid)??.0)?; - } - } else { - // New format: struct with context-tagged octet-strings - let container = root.container()?; +impl KvBlobStore for DummyKvBlobStore { + fn load<'a>(&mut self, _key: u16, _buf: &'a mut [u8]) -> Result, Error> { + Ok(None) + } - matter.load_fabrics(container.find_ctx(0)?.octets()?)?; - matter.load_basic_info(container.find_ctx(1)?.octets()?)?; + fn store(&mut self, _key: u16, _data: &[u8], _buf: &mut [u8]) -> Result<(), Error> { + Ok(()) + } - if let Some(networks) = networks { - networks.load(container.find_ctx(2)?.octets()?)?; - } + fn remove(&mut self, _key: u16, _buf: &mut [u8]) -> Result<(), Error> { + Ok(()) + } +} + +/// A trait representing access to a `KvBlobStore` instance and a buffer for its use. +pub trait KvBlobStoreAccess { + /// Get the `KvBlobStore` instance and buffer provided by this access. 
+ fn access(&self, f: F) -> R + where + F: FnOnce(&mut dyn KvBlobStore, &mut [u8]) -> R; +} + +impl KvBlobStoreAccess for &T +where + T: KvBlobStoreAccess, +{ + fn access(&self, f: F) -> R + where + F: FnOnce(&mut dyn KvBlobStore, &mut [u8]) -> R, + { + T::access(self, f) + } +} + +/// A noop implementation of the `KvBlobStoreAccess` trait. +pub struct DummyKvBlobStoreAccess; + +impl KvBlobStoreAccess for DummyKvBlobStoreAccess { + fn access(&self, f: F) -> R + where + F: FnOnce(&mut dyn KvBlobStore, &mut [u8]) -> R, + { + f(&mut DummyKvBlobStore, &mut []) + } +} + +/// An implementation of the `KvBlobStoreAccess` trait that provides access +/// to a shared `KvBlobStore` instance and a shared buffer using async mutex. +pub struct SharedKvBlobStore(Mutex>); + +impl SharedKvBlobStore { + /// Create a new `SharedKvBlobStore` instance. + /// + /// # Arguments + /// - `store` - the wrapped `KvBlobStore` instance + /// - `buf` - the wrapped buffer + pub const fn new(store: S, buf: T) -> Self { + Self(Mutex::new(RefCell::new((store, buf)))) + } +} - if let Some(events) = events { - events.load(container.find_ctx(3)?.octets()?)?; +impl KvBlobStoreAccess for SharedKvBlobStore +where + S: KvBlobStore, + T: BorrowMut<[u8]>, +{ + fn access(&self, f: F) -> R + where + F: FnOnce(&mut dyn KvBlobStore, &mut [u8]) -> R, + { + self.0.lock(|cell| { + let mut kvb = cell.borrow_mut(); + let kvb = &mut *kvb; + + f(&mut kvb.0, kvb.1.borrow_mut()) + }) + } +} + +/// A utility for persisting a value in a `KvBlobStore` instance. +pub struct Persist { + kvb: S, +} + +impl Persist +where + S: KvBlobStoreAccess, +{ + /// Create a new `Persist` instance with the given key-value store instance. + pub const fn new(kvb: S) -> Self { + Self { kvb } + } + + /// Save a value in the storage with the specified key by calling the provided closure to serialize the value into a buffer. 
+ pub fn store Result, Error>>( + &mut self, + key: u16, + f: F, + ) -> Result<(), Error> { + self.kvb.access(|kvb, buf| { + if !buf.is_empty() { + // DummyKvBlobStoreAccess uses an empty buffer + if let Some(len) = f(buf)? { + let (data, buf) = buf.split_at_mut(len); + kvb.store(key, data, buf)?; } } Ok(()) - } + }) + } - /// Store the persistent state from the provided `Matter` instance - /// - /// If the fabrics, basic info settings or wireless networks (if provided) have not changed, - /// this method does nothing. - /// - /// Arguments: - /// - `path`: The file path where to store the persistent state. - /// - `matter`: The `Matter` instance whose state to store (for fabrics and basic info settings). - /// - `networks`: An optional reference to `WirelessNetworks` whose state to store. - /// - `events`: An optional reference to `Events` whose state to store. - pub fn store( - &mut self, - path: P, - matter: &Matter, - networks: Option<&WirelessNetworks>, - events: Option<&Events>, - ) -> Result<(), Error> - where - P: AsRef, - T: WirelessNetwork, - { - if !matter.fabrics_changed() - && !matter.basic_info_changed() - && !networks.map(|networks| networks.changed()).unwrap_or(false) - && !events.map(|events| events.changed()).unwrap_or(false) - { - return Ok(()); - } + /// Save a value that implements the `ToTLV` trait in the storage with the specified key. + pub fn store_tlv(&mut self, key: u16, tlv: T) -> Result<(), Error> { + self.store(key, |buf| { + let mut wb = WriteBuf::new(buf); - let buf = unsafe { self.buf.assume_init_mut() }; + tlv.to_tlv(&TLVTag::Anonymous, &mut wb)?; - let mut wb = WriteBuf::new(buf); + Ok(Some(wb.get_tail())) + }) + } - wb.start_struct(&TLVTag::Anonymous)?; + /// Remove the value with the specified key from the storage. 
+ pub fn remove(&mut self, key: u16) -> Result<(), Error> { + self.kvb.access(|kvb, buf| { + if !buf.is_empty() { + // DummyKvBlobStoreAccess uses an empty buffer + kvb.remove(key, buf)?; + } - wb.str_cb(&TLVTag::Context(0), |buf| matter.store_fabrics(buf))?; + Ok(()) + }) + } - wb.str_cb(&TLVTag::Context(1), |buf| matter.store_basic_info(buf))?; + /// Call at the end when finished with everything else + /// No-op for now + pub fn run(self) -> Result<(), Error> { + // No-op for now - if let Some(networks) = networks { - wb.str_cb(&TLVTag::Context(2), |buf| networks.store(buf))?; - } + Ok(()) + } +} - if let Some(events) = events { - wb.str_cb(&TLVTag::Context(3), |buf| events.store(buf))?; - } +#[cfg(feature = "std")] +mod fileio { + use std::collections::HashMap; + use std::fs::File; + use std::io::{Read, Write}; + use std::path::{Path, PathBuf}; - wb.end_container()?; + use crate::error::Error; - Self::save_storage(path.as_ref(), wb.as_slice())?; + use super::KvBlobStore; - Ok(()) - } + extern crate std; - /// Run the persistent storage, which waits for changes in the `Matter` instance - /// and the optional `WirelessNetworks` instance (if provided) and stores the state - /// to the given file path whenever a change occurs. - /// - /// Arguments: - /// - `path`: The file path where to store the persistent state. - /// - `matter`: The `Matter` instance to monitor for changes and for state to store (for fabrics and basic info settings). - /// - `networks`: An optional reference to `WirelessNetworks` to monitor for changes and for state to store (if provided). - /// - `events`: An optional reference to `Events` to monitor for changes and for state to store (if provided). 
- pub async fn run( - &mut self, - path: P, - matter: &Matter<'_>, - networks: Option<&WirelessNetworks>, - events: Option<&Events>, - ) -> Result<(), Error> - where - P: AsRef, - T: WirelessNetwork, - { - // NOTE: Calling `load` here does not make sense, because the `Psm::run` future / async method is executed - // concurrently with other `rs-matter` futures. Including the future (`Matter::run`) that takes a decision whether - // the state of `rs-matter` is such that it is not provisioned yet (no fabrics) and as such - // it has to open the basic commissioning window and print the QR code. - // - // User is supposed to instead explicitly call `load` before calling `Psm::run` and `Matter::run` - // self.load_networks(dir, networks)?; - - loop { - match (networks, events) { - (Some(networks), Some(events)) => { - select3( - matter.wait_persist(), - networks.wait_persist(), - events.wait_persist(), - ) - .await; - } - (Some(networks), None) => { - select(matter.wait_persist(), networks.wait_persist()).await; - } - (None, Some(events)) => { - select(matter.wait_persist(), events.wait_persist()).await; - } - (None, None) => { - matter.wait_persist().await; - } - } + /// An implementation of the `KvBlobStore` trait that stores the BLOBs in a directory. + /// + /// The BLOBs are stored in files named after the keys in the specified directory. + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct DirKvBlobStore( + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] std::path::PathBuf, + ); + + impl DirKvBlobStore { + /// Create a new `DirKvBlobStore` instance, which will persist + /// its settings in `/rs-matter`. + pub fn new_default() -> Self { + Self(std::env::temp_dir().join("rs-matter")) + } - self.store(path.as_ref(), matter, networks, events)?; - } + /// Create a new `DirKvBlobStore` instance. 
+ pub const fn new(path: std::path::PathBuf) -> Self { + Self(path) } - /// Loads the data from the provided file path into the given buffer. - /// - /// Returns `Ok(Some(&[u8]))` if data was successfully loaded, - /// `Ok(None)` if the file does not exist, or an `Err` if an error occurred. - fn load_storage<'b>(path: &Path, buf: &'b mut [u8]) -> Result, Error> { - match fs::File::open(path) { + /// Load a BLOB with the specified key from the directory. + pub fn load(&self, key: u16, buf: &mut [u8]) -> Result, Error> { + let path = self.key_path(key); + + match File::open(path) { Ok(mut file) => { let mut offset = 0; loop { if offset == buf.len() { - Err(ErrorCode::BufferTooSmall)?; + Err(crate::error::ErrorCode::NoSpace)?; } let len = file.read(&mut buf[offset..])?; @@ -270,162 +308,202 @@ pub mod fileio { let data = &buf[..offset]; - trace!("Loaded {} bytes {:?}", data.len(), data); + debug!("Key {}: loaded {}B ({:?})", key, data.len(), data); - Ok(Some(data)) + Ok(Some(data.len())) } Err(_) => Ok(None), } } - /// Saves the given data to the specified file path. - fn save_storage(path: &Path, data: &[u8]) -> Result<(), Error> { - let mut file = fs::File::create(path)?; + /// Store a BLOB with the specified key in the directory. + pub fn store(&self, key: u16, data: &[u8]) -> Result<(), Error> { + let path = self.key_path(key); + + std::fs::create_dir_all(unwrap!(path.parent()))?; + + let mut file = File::create(path)?; file.write_all(data)?; - trace!("Stored {} bytes {:?}", data.len(), data); + debug!("Key {}: stored {}B ({:?})", key, data.len(), data); Ok(()) } + + /// Remove a BLOB with the specified key from the directory. + /// If the BLOB does not exist, this method does nothing. 
+ pub fn remove(&self, key: u16) -> Result<(), Error> { + let path = self.key_path(key); + + if std::fs::remove_file(path).is_ok() { + debug!("Key {}: removed", key); + } + + Ok(()) + } + + fn key_path(&self, key: u16) -> std::path::PathBuf { + self.0.join(format!("k_{key:04x}")) + } + } + + impl Default for DirKvBlobStore { + fn default() -> Self { + Self::new_default() + } } - #[cfg(test)] - mod tests { - use crate::dm::devices::test::{TEST_DEV_ATT, TEST_DEV_COMM, TEST_DEV_DET}; - use crate::dm::events::{Events, PersistedState}; - use crate::utils::epoch::sys_epoch; - use crate::MATTER_PORT; - - use super::*; - - fn new_test_matter() -> Matter<'static> { - let matter = Matter::new( - &TEST_DEV_DET, - TEST_DEV_COMM, - &TEST_DEV_ATT, - sys_epoch, - MATTER_PORT, - ); - - matter.with_state(|state| { - state.fabrics.add_with_post_init(|_| Ok(())).unwrap(); - }); - - matter + impl KvBlobStore for DirKvBlobStore { + fn load<'a>(&mut self, key: u16, buf: &'a mut [u8]) -> Result, Error> { + Ok(Self::load(self, key, buf)?.map(|len| &buf[..len])) + } + + fn store(&mut self, key: u16, data: &[u8], _buf: &mut [u8]) -> Result<(), Error> { + Self::store(self, key, data) + } + + fn remove(&mut self, key: u16, _buf: &mut [u8]) -> Result<(), Error> { + Self::remove(self, key) + } + } + + /// An implementation of the `KvBlobStore` trait that stores all BLOBs in a single file. + /// + /// While the implementation is very inefficient, it is necessary when testing with the C++ SDK test harness, + /// as it expects all data to be persisted as a single file (`/tmp/chip_kvs`). + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct FileKvBlobStore { + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] + path: std::path::PathBuf, + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] + blobs: Option>>, + } + + impl FileKvBlobStore { + /// Create a new `FileKvBlobStore` instance, which will persist its settings in `/tmp/chip_kvs`. 
+ pub fn new_default() -> Self { + Self::new(PathBuf::from("/tmp/chip_kvs")) + } + + /// Create a new `FileKvBlobStore` instance. + pub const fn new(path: PathBuf) -> Self { + Self { path, blobs: None } + } + + /// Load a BLOB with the specified key from the file. + pub fn load(&mut self, key: u16, buf: &mut [u8]) -> Result, Error> { + self.initialize()?; + + let blobs = self.blobs.as_ref().unwrap(); + + if let Some(blob) = blobs.get(&key) { + if blob.len() > buf.len() { + Err(crate::error::ErrorCode::NoSpace)?; + } + + buf[..blob.len()].copy_from_slice(blob); + + Ok(Some(blob.len())) + } else { + Ok(None) + } + } + + /// Store a BLOB with the specified key in the directory. + pub fn store(&mut self, key: u16, data: &[u8]) -> Result<(), Error> { + self.initialize()?; + + let blobs = self.blobs.as_mut().unwrap(); + + blobs.insert(key, data.to_vec()); + + Self::save_all(&self.path, blobs) + } + + /// Remove a BLOB with the specified key from the directory. + /// If the BLOB does not exist, this method does nothing. 
+ pub fn remove(&mut self, key: u16) -> Result<(), Error> { + self.initialize()?; + + let blobs = self.blobs.as_mut().unwrap(); + + blobs.remove(&key); + + Self::save_all(&self.path, blobs) } - #[test] - fn test_store_load_roundtrip() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("persist.bin"); - - // Set up a matter instance with some non-default config - let initial_matter = new_test_matter(); - { - initial_matter.with_state(|state| { - let basic = &mut state.basic_info_settings; - basic.node_label = heapless::String::try_from("my-test-node").unwrap(); - basic.location = Some(heapless::String::try_from("ab").unwrap()); - basic.changed = true; - }); + fn initialize(&mut self) -> Result<(), Error> { + if self.blobs.is_none() { + let mut blobs = HashMap::new(); + + Self::load_all(&self.path, &mut blobs)?; + + self.blobs = Some(blobs); } - // Set up events with a recognizable epoch value - let events = Events::<64>::new(sys_epoch); - events.persisted_state.lock(|cell| { - cell.set(PersistedState { - next_event_no: 0, - event_epoch_end: 0xDEADBEEF, - changed: true, - }); - }); - - let mut psm = Psm::<32768>::new(); - psm.store(&path, &initial_matter, NO_NETWORKS, Some(&events)) - .unwrap(); - - assert!(path.exists()); - assert!(std::fs::metadata(&path).unwrap().len() > 0); - - // Load into fresh instances - let roundtripped = Matter::new( - &TEST_DEV_DET, - TEST_DEV_COMM, - &TEST_DEV_ATT, - sys_epoch, - MATTER_PORT, - ); - let roundtripped_events = Events::<64>::new(sys_epoch); - - let mut psm2 = Psm::<32768>::new(); - psm2.load( - &path, - &roundtripped, - NO_NETWORKS, - Some(&roundtripped_events), - ) - .unwrap(); - - // Basic info fields should've been restored - roundtripped.with_state(|state| { - let basic = &state.basic_info_settings; - assert_eq!(basic.node_label, "my-test-node"); - assert_eq!(basic.location.as_deref(), Some("ab")); - }); - - // Events epoch should've been restored and bumped by one epoch - let events = 
roundtripped_events.persisted_state.lock(|cell| cell.get()); - assert_eq!(events.next_event_no, 0xDEADBEEF); - assert_eq!(events.event_epoch_end, 0xDEADBEEF + 0x10000); + Ok(()) + } + + fn load_all(path: &Path, blobs: &mut HashMap>) -> Result<(), Error> { + if let Ok(mut file) = File::open(path) { + loop { + let mut key_buf = [0; 2]; + + if file.read_exact(&mut key_buf).is_err() { + break; + } + + let key = u16::from_le_bytes(key_buf); + + let mut len_buf = [0; 4]; + + file.read_exact(&mut len_buf)?; + + let len = u32::from_le_bytes(len_buf) as usize; + + let mut data = vec![0; len]; + + file.read_exact(&mut data)?; + + blobs.insert(key, data); + } + } + + Ok(()) + } + + fn save_all(path: &Path, blobs: &HashMap>) -> Result<(), Error> { + let mut file = File::create(path)?; + + for (key, data) in blobs { + file.write_all(&key.to_le_bytes())?; + file.write_all(&(data.len() as u32).to_le_bytes())?; + file.write_all(data)?; + } + + Ok(()) + } + } + + impl Default for FileKvBlobStore { + fn default() -> Self { + Self::new_default() + } + } + + impl KvBlobStore for FileKvBlobStore { + fn load<'a>(&mut self, key: u16, buf: &'a mut [u8]) -> Result, Error> { + Ok(Self::load(self, key, buf)?.map(|len| &buf[..len])) + } + + fn store(&mut self, key: u16, data: &[u8], _buf: &mut [u8]) -> Result<(), Error> { + Self::store(self, key, data) } - #[test] - fn test_load_legacy_format() { - // Generate a "legacy" blob using the old array-based format - // (anonymous array with positional anonymous octet-strings) - let source_matter = new_test_matter(); - source_matter.with_state(|state| { - let basic = &mut state.basic_info_settings; - basic.node_label = heapless::String::try_from("my-test-node").unwrap(); - basic.location = Some(heapless::String::try_from("ab").unwrap()); - }); - - let mut buf = [0u8; 4096]; - let mut wb = crate::utils::storage::WriteBuf::new(&mut buf); - wb.start_array(&crate::tlv::TLVTag::Anonymous).unwrap(); - wb.str_cb(&crate::tlv::TLVTag::Anonymous, |buf| { - 
source_matter.store_fabrics(buf) - }) - .unwrap(); - wb.str_cb(&crate::tlv::TLVTag::Anonymous, |buf| { - source_matter.store_basic_info(buf) - }) - .unwrap(); - wb.end_container().unwrap(); - let tail = wb.get_tail(); - let legacy_blob = &buf[..tail]; - - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("legacy.bin"); - std::fs::write(&path, legacy_blob).unwrap(); - - let matter = Matter::new( - &TEST_DEV_DET, - TEST_DEV_COMM, - &TEST_DEV_ATT, - sys_epoch, - MATTER_PORT, - ); - - let mut psm = Psm::<32768>::new(); - psm.load(&path, &matter, NO_NETWORKS, NO_EVENTS).unwrap(); - - matter.with_state(|state| { - let basic = &state.basic_info_settings; - assert_eq!(basic.node_label, "my-test-node"); - assert_eq!(basic.location.as_deref(), Some("ab")); - }); + fn remove(&mut self, key: u16, _buf: &mut [u8]) -> Result<(), Error> { + Self::remove(self, key) } } } diff --git a/rs-matter/src/respond.rs b/rs-matter/src/respond.rs index 4076c8906..df17c1a5e 100644 --- a/rs-matter/src/respond.rs +++ b/rs-matter/src/respond.rs @@ -27,6 +27,7 @@ use crate::dm::{DataModel, IMBuffer}; use crate::error::Error; use crate::im::busy::BusyInteractionModel; use crate::im::PROTO_ID_INTERACTION_MODEL; +use crate::persist::KvBlobStoreAccess; use crate::sc::busy::BusySecureChannel; use crate::sc::SecureChannel; use crate::transport::exchange::Exchange; @@ -237,21 +238,22 @@ where } /// A type alias for the "default" responder handler, which is a chained handler of the `DataModel` and `SecureChannel` handlers. 
-pub type DefaultExchangeHandler<'d, 'a, const NS: usize, const NE: usize, C, B, T> = - ChainedExchangeHandler<&'d DataModel<'a, NS, NE, C, B, T>, SecureChannel<'d, &'d C>>; +pub type DefaultExchangeHandler<'d, 'a, const NS: usize, const NE: usize, C, B, T, S> = + ChainedExchangeHandler<&'d DataModel<'a, NS, NE, C, B, T, S>, SecureChannel<'d, &'d C>>; -impl<'d, 'a, const NS: usize, const NE: usize, C, B, T> - Responder<'a, DefaultExchangeHandler<'d, 'a, NS, NE, C, B, T>> +impl<'d, 'a, const NS: usize, const NE: usize, C, B, T, S> + Responder<'a, DefaultExchangeHandler<'d, 'a, NS, NE, C, B, T, S>> where - C: Crypto, B: BufferAccess, { /// Creates a "default" responder. This is a responder that composes and uses the `rs-matter`-provided `ExchangeHandler` implementations /// (`SecureChannel` and `DataModel`) for handling the Secure Channel protocol and the Interaction Model protocol. #[inline(always)] - pub const fn new_default(data_model: &'d DataModel<'a, NS, NE, C, B, T>) -> Self + pub const fn new_default(data_model: &'d DataModel<'a, NS, NE, C, B, T, S>) -> Self where + C: Crypto, T: DataModelHandler, + S: KvBlobStoreAccess, { Self::new( "Responder", @@ -292,23 +294,25 @@ impl<'a> Responder<'a, BusyExchangeHandler> { } /// A composition of the `Responder::new_default` and `Responder::new_busy` responders. 
-pub struct DefaultResponder<'d, 'a, const NS: usize, const NE: usize, C, B, T> +pub struct DefaultResponder<'d, 'a, const NS: usize, const NE: usize, C, B, T, S> where B: BufferAccess, { - responder: Responder<'a, DefaultExchangeHandler<'d, 'a, NS, NE, C, B, T>>, + responder: Responder<'a, DefaultExchangeHandler<'d, 'a, NS, NE, C, B, T, S>>, busy_responder: Responder<'a, BusyExchangeHandler>, } -impl<'d, 'a, const NS: usize, const NE: usize, C, B, T> DefaultResponder<'d, 'a, NS, NE, C, B, T> +impl<'d, 'a, const NS: usize, const NE: usize, C, B, T, S> + DefaultResponder<'d, 'a, NS, NE, C, B, T, S> where + C: Crypto, B: BufferAccess, T: DataModelHandler, - C: Crypto, + S: KvBlobStoreAccess, { /// Creates the responder composition. #[inline(always)] - pub const fn new(data_model: &'d DataModel<'a, NS, NE, C, B, T>) -> Self { + pub const fn new(data_model: &'d DataModel<'a, NS, NE, C, B, T, S>) -> Self { Self { responder: Responder::new_default(data_model), busy_responder: Responder::new_busy(data_model.matter(), RESPOND_BUSY_MS), @@ -332,7 +336,7 @@ where &self, ) -> &Responder< 'a, - ChainedExchangeHandler<&'d DataModel<'a, NS, NE, C, B, T>, SecureChannel<'d, &'d C>>, + ChainedExchangeHandler<&'d DataModel<'a, NS, NE, C, B, T, S>, SecureChannel<'d, &'d C>>, > { &self.responder } diff --git a/rs-matter/src/sc/case.rs b/rs-matter/src/sc/case.rs index 259cf91db..ca9de7ae3 100644 --- a/rs-matter/src/sc/case.rs +++ b/rs-matter/src/sc/case.rs @@ -108,7 +108,6 @@ impl<'a, C: Crypto> Case<'a, C> { self.handle_casesigma3(exchange, session).await?; exchange.acknowledge().await?; - exchange.matter().notify_persist(); Ok(()) } diff --git a/rs-matter/src/sc/pase.rs b/rs-matter/src/sc/pase.rs index 7146a543a..7777b14e9 100644 --- a/rs-matter/src/sc/pase.rs +++ b/rs-matter/src/sc/pase.rs @@ -24,17 +24,14 @@ use core::num::NonZeroU8; use core::ops::Add; use core::time::Duration; -use rand_core::RngCore; use spake2p::Spake2pVerifierData; -use crate::crypto::Crypto; use 
crate::dm::clusters::adm_comm::{self}; use crate::dm::endpoints::ROOT_ENDPOINT_ID; -use crate::dm::BasicContext; use crate::error::{Error, ErrorCode}; +use crate::im::{AttrId, ClusterId, EndptId}; use crate::sc::pase::spake2p::{ Spake2pVerifierPasswordRef, Spake2pVerifierSaltRef, Spake2pVerifierStrRef, - SPAKE2P_VERIFIER_SALT_ZEROED, }; use crate::sc::SessionParameters; use crate::tlv::{FromTLV, OctetStr, ToTLV}; @@ -205,10 +202,11 @@ impl Pase { }) } - pub fn comm_window(&mut self, ctx: C) -> Result, Error> - where - C: BasicContext, - { + pub fn comm_window( + &mut self, + notify_mdns: impl FnMut(), + notify_change: impl FnMut(EndptId, ClusterId, AttrId), + ) -> Result, Error> { let expired = self .comm_window .as_opt_ref() @@ -218,7 +216,7 @@ impl Pase { if expired { warn!("PASE Commissioning Window expired, closing"); - self.close_comm_window(ctx)?; + self.close_comm_window(notify_mdns, notify_change)?; Ok(None) } else { @@ -240,18 +238,22 @@ impl Pase { /// - `Err(Error)` if an error occurred /// (i.e. there is another non-expired commissioning window already opened /// or the timeout is invalid) - pub fn open_basic_comm_window( + #[allow(clippy::too_many_arguments)] + pub fn open_basic_comm_window( &mut self, - ctx: C, + mdns_id: u64, + salt: Spake2pVerifierSaltRef<'_>, password: Spake2pVerifierPasswordRef<'_>, discriminator: u16, timeout_secs: u16, opener: Option, - ) -> Result<(), Error> - where - C: BasicContext, - { - if self.comm_window(&ctx)?.is_some() { + mut notify_mdns: impl FnMut(), + mut notify_change: impl FnMut(EndptId, ClusterId, AttrId), + ) -> Result<(), Error> { + if self + .comm_window(&mut notify_mdns, &mut notify_change)? 
+ .is_some() + { Err(ErrorCode::Busy)?; } @@ -261,27 +263,18 @@ impl Pase { let window_expiry = (self.epoch)().add(Duration::from_secs(timeout_secs as _)); - let crypto = ctx.crypto(); - let mut rand = crypto.rand()?; - - let mdns_id = rand.next_u64(); - - let mut salt = SPAKE2P_VERIFIER_SALT_ZEROED; - rand.fill_bytes(salt.access_mut()); - self.comm_window .reinit(Maybe::init_some(CommWindow::init_with_pw( mdns_id, password, - salt.reference(), + salt, discriminator, opener, window_expiry, ))); - ctx.matter().notify_mdns(); - - ctx.notify_attribute_changed( + notify_mdns(); + notify_change( ROOT_ENDPOINT_ID, adm_comm::FULL_CLUSTER.id, adm_comm::AttributeId::WindowStatus as _, @@ -309,20 +302,22 @@ impl Pase { /// (i.e. there is another non-expired commissioning window already opened /// or the timeout is invalid) #[allow(clippy::too_many_arguments)] - pub fn open_comm_window( + pub fn open_comm_window( &mut self, - ctx: C, + mdns_id: u64, verifier: Spake2pVerifierStrRef<'_>, salt: Spake2pVerifierSaltRef<'_>, count: u32, discriminator: u16, timeout_secs: u16, opener: Option, - ) -> Result<(), Error> - where - C: BasicContext, - { - if self.comm_window(&ctx)?.is_some() { + mut notify_mdns: impl FnMut(), + mut notify_change: impl FnMut(EndptId, ClusterId, AttrId), + ) -> Result<(), Error> { + if self + .comm_window(&mut notify_mdns, &mut notify_change)? 
+ .is_some() + { Err(ErrorCode::Busy)?; } @@ -332,11 +327,6 @@ impl Pase { let window_expiry = (self.epoch)().add(Duration::from_secs(timeout_secs as _)); - let crypto = ctx.crypto(); - let mut rand = crypto.rand()?; - - let mdns_id = rand.next_u64(); - self.comm_window.reinit(Maybe::init_some(CommWindow::init( mdns_id, verifier, @@ -347,9 +337,8 @@ impl Pase { window_expiry, ))); - ctx.matter().notify_mdns(); - - ctx.notify_attribute_changed( + notify_mdns(); + notify_change( ROOT_ENDPOINT_ID, adm_comm::FULL_CLUSTER.id, adm_comm::AttributeId::WindowStatus as _, @@ -368,15 +357,16 @@ impl Pase { /// # Returns /// - `Ok(true)` if a commissioning window was closed /// - `Ok(false)` if there was no commissioning window to close - pub fn close_comm_window(&mut self, ctx: C) -> Result - where - C: BasicContext, - { + pub fn close_comm_window( + &mut self, + mut notify_mdns: impl FnMut(), + mut notify_change: impl FnMut(EndptId, ClusterId, AttrId), + ) -> Result { if self.comm_window.is_some() { self.comm_window.clear(); - ctx.matter().notify_mdns(); - ctx.notify_attribute_changed( + notify_mdns(); + notify_change( ROOT_ENDPOINT_ID, adm_comm::FULL_CLUSTER.id, adm_comm::AttributeId::WindowStatus as _, diff --git a/rs-matter/src/sc/pase/responder.rs b/rs-matter/src/sc/pase/responder.rs index 156ebe45a..c8aa13182 100644 --- a/rs-matter/src/sc/pase/responder.rs +++ b/rs-matter/src/sc/pase/responder.rs @@ -26,7 +26,7 @@ use crate::crypto::{ CanonEcPointRef, Crypto, HmacHashRef, Kdf, AEAD_CANON_KEY_LEN, EC_POINT_ZEROED, HMAC_HASH_ZEROED, }; -use crate::dm::{BasicContextInstance, ChangeNotify}; +use crate::dm::ChangeNotify; use crate::error::{Error, ErrorCode}; use crate::sc::pase::spake2p::{Spake2P, Spake2pRandom, Spake2pRandomRef, Spake2pSessionKeys}; use crate::sc::{check_opcode, complete_with_status, OpCode, SCStatusCodes}; @@ -96,7 +96,6 @@ impl<'a, C: Crypto> PaseResponder<'a, C> { self.handle_pasepake3(exchange, session).await?; exchange.acknowledge().await?; - 
exchange.matter().notify_persist(); self.clear_session_timeout(exchange) } @@ -113,11 +112,13 @@ impl<'a, C: Crypto> PaseResponder<'a, C> { let mut salt = super::spake2p::SPAKE2P_VERIFIER_SALT_ZEROED; let mut count = 0; - let has_comm_window = { - let ctx = BasicContextInstance::new(exchange.matter(), &self.crypto, self.notify); + let notify_mdns = || exchange.matter().notify_mdns(); + let notify_change = + |endpt_id, cluster_id, attr_id| self.notify.notify(endpt_id, cluster_id, attr_id); + let has_comm_window = { exchange.with_state(|state| { - if let Some(comm_window) = state.pase.comm_window(&ctx)? { + if let Some(comm_window) = state.pase.comm_window(notify_mdns, notify_change)? { salt.load(comm_window.verifier.salt.reference()); count = comm_window.verifier.count; @@ -200,11 +201,13 @@ impl<'a, C: Crypto> PaseResponder<'a, C> { let mut b_pt = EC_POINT_ZEROED; let mut cb = HMAC_HASH_ZEROED; - let has_comm_window = { - let ctx = BasicContextInstance::new(exchange.matter(), &self.crypto, self.notify); + let notify_mdns = || exchange.matter().notify_mdns(); + let notify_change = + |endpt_id, cluster_id, attr_id| self.notify.notify(endpt_id, cluster_id, attr_id); + let has_comm_window = { exchange.with_state(|state| { - if let Some(comm_window) = state.pase.comm_window(&ctx)? { + if let Some(comm_window) = state.pase.comm_window(notify_mdns, notify_change)? 
{ self.spake2p.setup_verifier( &self.crypto, &comm_window.verifier, diff --git a/rs-matter/src/transport.rs b/rs-matter/src/transport.rs index fe36353dd..480e4ad36 100644 --- a/rs-matter/src/transport.rs +++ b/rs-matter/src/transport.rs @@ -410,7 +410,7 @@ impl<'a, C: Crypto> TransportRunner<'a, C> { let addr_op = self.matter.with_state(|state| { let group_addrs = || { state.fabrics.iter().flat_map(|fabric| { - fabric.group_iter().map(|group| { + fabric.groups().iter().map(|group| { compute_group_multicast_addr(fabric.fabric_id(), group.group_id) }) }) diff --git a/rs-matter/src/transport/network/mdns/astro.rs b/rs-matter/src/transport/network/mdns/astro.rs index 74f597946..8d49cee52 100644 --- a/rs-matter/src/transport/network/mdns/astro.rs +++ b/rs-matter/src/transport/network/mdns/astro.rs @@ -25,9 +25,8 @@ use std::collections::{HashMap, HashSet}; use std::net::ToSocketAddrs; use std::time::Duration; -use crate::crypto::Crypto; -use crate::dm::ChangeNotify; use crate::error::{Error, ErrorCode}; +use crate::im::{AttrId, ClusterId, EndptId}; use crate::transport::network::mdns::Service; use crate::{Matter, MatterMdnsService}; @@ -57,16 +56,15 @@ impl<'a> AstroMdnsResponder<'a> { /// # Arguments /// - `crypto`: A crypto provider instance. /// - `notify`: A change notification interface. 
- pub async fn run( + pub async fn run( &mut self, - crypto: C, - notify: &dyn ChangeNotify, + mut notify: impl FnMut(EndptId, ClusterId, AttrId), ) -> Result<(), Error> { loop { self.matter.wait_mdns().await; let mut services = HashSet::new(); - self.matter.mdns_services(&crypto, notify, |service| { + self.matter.mdns_services(&mut notify, |service| { services.insert(service); Ok(()) @@ -165,7 +163,7 @@ pub fn discover_commissionable( info!("Browsing for mDNS services: {}", service_type); - let browser = ServiceBrowserBuilder::new(&service_type) + let browser = ServiceBrowserBuilder::new(service_type) .browse() .map_err(|e| { error!("Failed to create service browser: {:?}", e); diff --git a/rs-matter/src/transport/network/mdns/avahi.rs b/rs-matter/src/transport/network/mdns/avahi.rs index cdde29ac2..e21f1d8d8 100644 --- a/rs-matter/src/transport/network/mdns/avahi.rs +++ b/rs-matter/src/transport/network/mdns/avahi.rs @@ -33,9 +33,8 @@ use futures_lite::StreamExt; use zbus::zvariant::{ObjectPath, OwnedObjectPath}; use zbus::Connection; -use crate::crypto::Crypto; -use crate::dm::ChangeNotify; use crate::error::Error; +use crate::im::{AttrId, ClusterId, EndptId}; use crate::transport::network::mdns::Service; use crate::utils::zbus_proxies::avahi::entry_group::EntryGroupProxy; use crate::utils::zbus_proxies::avahi::server2::Server2Proxy; @@ -70,11 +69,10 @@ impl<'a> AvahiMdnsResponder<'a> { /// - `connection`: A reference to the DBus system connection to use for communication with Avahi. /// - `crypto`: A crypto provider instance. /// - `notify`: A change notification interface. 
- pub async fn run( + pub async fn run( &mut self, connection: &Connection, - crypto: C, - notify: &dyn ChangeNotify, + mut notify: impl FnMut(EndptId, ClusterId, AttrId), ) -> Result<(), Error> { { let avahi = Server2Proxy::new(connection).await?; @@ -85,7 +83,7 @@ impl<'a> AvahiMdnsResponder<'a> { self.matter.wait_mdns().await; let mut services = HashSet::new(); - self.matter.mdns_services(&crypto, notify, |service| { + self.matter.mdns_services(&mut notify, |service| { services.insert(service); Ok(()) diff --git a/rs-matter/src/transport/network/mdns/builtin.rs b/rs-matter/src/transport/network/mdns/builtin.rs index 56ed0c52d..2a8ef099d 100644 --- a/rs-matter/src/transport/network/mdns/builtin.rs +++ b/rs-matter/src/transport/network/mdns/builtin.rs @@ -248,14 +248,16 @@ where where F: FnMut(&Service) -> Result<(), Error>, { - self.matter - .mdns_services(&self.crypto, self.notify, |service| { - Service::call_with( - &service, - self.matter.dev_det(), - self.matter.port(), - &mut callback, - ) - }) + let notify_change = + |endpt_id, clust_id, attr_id| self.notify.notify(endpt_id, clust_id, attr_id); + + self.matter.mdns_services(notify_change, |service| { + Service::call_with( + &service, + self.matter.dev_det(), + self.matter.port(), + &mut callback, + ) + }) } } diff --git a/rs-matter/src/transport/network/mdns/resolve.rs b/rs-matter/src/transport/network/mdns/resolve.rs index 0a365e3ff..9c00f8d22 100644 --- a/rs-matter/src/transport/network/mdns/resolve.rs +++ b/rs-matter/src/transport/network/mdns/resolve.rs @@ -26,9 +26,8 @@ use domain::base::Name; use zbus::zvariant::{ObjectPath, OwnedObjectPath}; use zbus::Connection; -use crate::crypto::Crypto; -use crate::dm::ChangeNotify; use crate::error::Error; +use crate::im::{AttrId, ClusterId, EndptId}; use crate::transport::network::mdns::Service; use crate::utils::zbus_proxies::resolve::manager::ManagerProxy; use crate::{Matter, MatterMdnsService}; @@ -86,17 +85,16 @@ impl<'a> ResolveMdnsResponder<'a> { /// - 
`connection`: A reference to the DBus system connection to use for communication with Avahi. /// - `crypto`: A crypto provider instance. /// - `notify`: A change notification interface. - pub async fn run( + pub async fn run( &mut self, connection: &Connection, - crypto: C, - notify: &dyn ChangeNotify, + mut notify: impl FnMut(EndptId, ClusterId, AttrId), ) -> Result<(), Error> { loop { self.matter.wait_mdns().await; let mut services = HashSet::new(); - self.matter.mdns_services(&crypto, notify, |service| { + self.matter.mdns_services(&mut notify, |service| { services.insert(service); Ok(()) diff --git a/rs-matter/src/transport/network/mdns/zeroconf.rs b/rs-matter/src/transport/network/mdns/zeroconf.rs index f3221f099..b101209ab 100644 --- a/rs-matter/src/transport/network/mdns/zeroconf.rs +++ b/rs-matter/src/transport/network/mdns/zeroconf.rs @@ -30,9 +30,8 @@ use zeroconf::service::TMdnsService; use zeroconf::txt_record::TTxtRecord; use zeroconf::{MdnsBrowser, ServiceDiscovery, ServiceType}; -use crate::crypto::Crypto; -use crate::dm::ChangeNotify; use crate::error::{Error, ErrorCode}; +use crate::im::{AttrId, ClusterId, EndptId}; use crate::transport::network::mdns::Service; use crate::{Matter, MatterMdnsService}; @@ -59,16 +58,15 @@ impl<'a> ZeroconfMdnsResponder<'a> { /// # Arguments /// - `crypto`: A crypto provider instance. /// - `notify`: A change notification interface. 
- pub async fn run( + pub async fn run( &mut self, - crypto: C, - notify: &dyn ChangeNotify, + mut notify: impl FnMut(EndptId, ClusterId, AttrId), ) -> Result<(), Error> { loop { self.matter.wait_mdns().await; let mut services = HashSet::new(); - self.matter.mdns_services(&crypto, notify, |service| { + self.matter.mdns_services(&mut notify, |service| { services.insert(service); Ok(()) diff --git a/rs-matter/src/transport/session.rs b/rs-matter/src/transport/session.rs index 5c98b3ee0..0b4f0d630 100644 --- a/rs-matter/src/transport/session.rs +++ b/rs-matter/src/transport/session.rs @@ -800,12 +800,12 @@ impl Sessions { let fab_idx = fabric.fab_idx(); let compressed_fabric_id = fabric.compressed_fabric_id(); - for map_entry in fabric.group_key_map_iter() { + for map_entry in fabric.groups().key_map_iter() { if map_entry.group_id != group_id { continue; } - let Some(key_set_entry) = fabric.group_key_set_get(map_entry.group_key_set_id) + let Some(key_set_entry) = fabric.groups().key_set_get(map_entry.group_key_set_id) else { continue; }; diff --git a/rs-matter/tests/commissioning.rs b/rs-matter/tests/commissioning.rs index be338d293..4b9be83ad 100644 --- a/rs-matter/tests/commissioning.rs +++ b/rs-matter/tests/commissioning.rs @@ -49,6 +49,7 @@ use log::{debug, info, warn}; use rand_core::RngCore; +use rs_matter::persist::DummyKvBlobStoreAccess; use socket2::{Domain, Protocol, Socket, Type}; use static_cell::StaticCell; @@ -196,6 +197,7 @@ async fn run_test() -> Result<(), Error> { device_subscriptions, NO_EVENTS, dm_handler(rand, &on_off_handler), + DummyKvBlobStoreAccess, ); // Open commissioning window before starting the mDNS responder so the diff --git a/rs-matter/tests/common/e2e.rs b/rs-matter/tests/common/e2e.rs index 90e9184ac..f2da46506 100644 --- a/rs-matter/tests/common/e2e.rs +++ b/rs-matter/tests/common/e2e.rs @@ -29,6 +29,7 @@ use rs_matter::dm::subscriptions::Subscriptions; use rs_matter::dm::{AsyncHandler, AsyncMetadata, Privilege}; use 
rs_matter::dm::{DataModel, IMBuffer}; use rs_matter::error::Error; +use rs_matter::persist::DummyKvBlobStoreAccess; use rs_matter::respond::Responder; use rs_matter::transport::exchange::Exchange; use rs_matter::transport::network::{ @@ -138,7 +139,9 @@ impl E2eRunner { self.matter.with_state(|state| { state .fabrics - .acl_add(NonZeroU8::new(1).unwrap(), default_acl) + .fabric_mut(NonZeroU8::new(1).unwrap()) + .unwrap() + .acl_add(default_acl) .unwrap(); }); } @@ -185,6 +188,7 @@ impl E2eRunner { &self.subscriptions, Some(&self.events), handler, + DummyKvBlobStoreAccess, ); let responder = Responder::new_default(&dm); diff --git a/rs-matter/tests/common/mdns.rs b/rs-matter/tests/common/mdns.rs index 15e336ab6..532c26979 100644 --- a/rs-matter/tests/common/mdns.rs +++ b/rs-matter/tests/common/mdns.rs @@ -38,12 +38,12 @@ pub async fn run_mdns( ) -> Result<(), Error> { #[cfg(feature = "astro-dnssd")] rs_matter::transport::network::mdns::astro::AstroMdnsResponder::new(matter) - .run(crypto, notify) + .run(|endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id)) .await?; #[cfg(all(feature = "zeroconf", not(feature = "astro-dnssd")))] rs_matter::transport::network::mdns::zeroconf::ZeroconfMdnsResponder::new(matter) - .run(crypto, notify) + .run(|endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id)) .await?; // Both `avahi` and `resolve` modules are compiled under the single `zbus` @@ -57,8 +57,7 @@ pub async fn run_mdns( rs_matter::transport::network::mdns::avahi::AvahiMdnsResponder::new(matter) .run( &rs_matter::utils::zbus::Connection::system().await.unwrap(), - crypto, - notify, + |endpt_id, clust_id, attr_id| notify.notify(endpt_id, clust_id, attr_id), ) .await?; diff --git a/rs-matter/tests/common/mod.rs b/rs-matter/tests/common/mod.rs index 87a3292ac..5c6756620 100644 --- a/rs-matter/tests/common/mod.rs +++ b/rs-matter/tests/common/mod.rs @@ -16,6 +16,7 @@ */ pub mod e2e; +#[cfg(feature = "async-io")] pub mod mdns; use 
core::future::Future; diff --git a/rs-matter/tests/data_model/acl_and_dataver.rs b/rs-matter/tests/data_model/acl_and_dataver.rs index e2d9a459c..cf6305d74 100644 --- a/rs-matter/tests/data_model/acl_and_dataver.rs +++ b/rs-matter/tests/data_model/acl_and_dataver.rs @@ -70,7 +70,12 @@ fn wc_read_attribute() { acl.add_subject(TEST_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test2: Only Single response as only single endpoint is allowed @@ -85,7 +90,12 @@ fn wc_read_attribute() { acl.add_subject(TEST_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test3: Both responses are valid @@ -126,7 +136,12 @@ fn exact_read_attribute() { let mut acl = AclEntry::new(None, Privilege::ADMIN, AuthMode::Case); acl.add_subject(TEST_PEER_ID).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test2: Only Single response as only single endpoint is allowed @@ -186,7 +201,12 @@ fn wc_write_attribute() { acl.add_subject(TEST_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test 2: Wildcard write to attributes will only return attributes @@ -207,7 +227,12 @@ fn wc_write_attribute() { acl.add_subject(TEST_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + 
.fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test 3: Wildcard write to attributes will return multiple attributes @@ -264,7 +289,12 @@ fn exact_write_attribute() { let mut acl = AclEntry::new(None, Privilege::ADMIN, AuthMode::Case); acl.add_subject(TEST_PEER_ID).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test 1: Exact write to an attribute with permission should grant @@ -318,7 +348,12 @@ fn exact_write_attribute_noc_cat() { let mut acl = AclEntry::new(None, Privilege::ADMIN, AuthMode::Case); acl.add_subject_catid(cat_in_acl).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test 1: Exact write to an attribute with permission should grant @@ -351,7 +386,12 @@ fn insufficient_perms_write() { acl.add_subject(TEST_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); // Test: Not enough permission should return error @@ -425,7 +465,12 @@ fn write_with_runtime_acl_add() { )) .unwrap(); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, basic_acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(basic_acl) + .unwrap(); }); // Test: deny write (with error), then ACL is added, then allow write @@ -458,7 +503,12 @@ fn test_read_data_ver() { // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(None, Privilege::OPERATE, AuthMode::Case); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); let wc_ep_att1 = GenericPath::new( @@ -553,7 +603,12 @@ fn 
test_write_data_ver() { // Add ACL to allow our peer with only OPERATE permission let acl = AclEntry::new(None, Privilege::ADMIN, AuthMode::Case); im.matter.with_state(|state| { - state.fabrics.acl_add(FAB_1, acl).unwrap(); + state + .fabrics + .fabric_mut(FAB_1) + .unwrap() + .acl_add(acl) + .unwrap(); }); let wc_ep_attwrite = GenericPath::new( diff --git a/rs-matter/tests/data_model/events.rs b/rs-matter/tests/data_model/events.rs index 45af25ccc..b7a00a34d 100644 --- a/rs-matter/tests/data_model/events.rs +++ b/rs-matter/tests/data_model/events.rs @@ -24,6 +24,7 @@ use rs_matter::im::GenericPath; use rs_matter::im::IMStatusCode; use rs_matter::im::StatusResp; use rs_matter::im::SubscribeResp; +use rs_matter::persist::DummyKvBlobStoreAccess; use rs_matter::tlv::{TLVTag, TLVWrite}; use rs_matter::utils::storage::WriteBuf; @@ -332,24 +333,27 @@ fn test_long_read_events() { ], ); } + fn push_events(im: &crate::common::e2e::E2eRunner, events: &[TestEventData]) where C: rs_matter::crypto::Crypto, { for ev in events { - block_on( - im.events - .push(ev.path.clone(), ev.priority, |tw| -> Result<(), Error> { - if let Some(data) = ev.data { - let mut b = [0u8; 2048]; - let mut wb = WriteBuf::new(&mut b[0..]); - data.test_to_tlv(&TLVTag::Context(EventDataTag::Data as _), &mut wb)?; - let end = wb.get_tail(); - tw.write_raw_data(b[..end].iter().copied())?; - } - Ok(()) - }), - ) + block_on(im.events.push( + ev.path.clone(), + ev.priority, + DummyKvBlobStoreAccess, + |tw| -> Result<(), Error> { + if let Some(data) = ev.data { + let mut b = [0u8; 2048]; + let mut wb = WriteBuf::new(&mut b[0..]); + data.test_to_tlv(&TLVTag::Context(EventDataTag::Data as _), &mut wb)?; + let end = wb.get_tail(); + tw.write_raw_data(b[..end].iter().copied())?; + } + Ok(()) + }, + )) .unwrap(); } }