-
Notifications
You must be signed in to change notification settings - Fork 205
feat(sozo): add controller session lifecycle commands #3395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
7ecaa19
27e5d54
7d3289e
b59c4ec
e13353c
956b56d
6b4ca53
139789e
99aa87d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,12 @@ | ||
| use std::collections::HashMap; | ||
| use std::io::{Read, Write}; | ||
| use std::net::{TcpListener, TcpStream}; | ||
| use std::str::FromStr; | ||
| use std::time::Duration; | ||
|
|
||
| use anyhow::{bail, Result}; | ||
| use anyhow::{Context, Result, anyhow, bail}; | ||
| use dojo_world::contracts::contract_info::ContractInfo; | ||
| use serde::{Deserialize, Serialize}; | ||
| use slot::account_sdk::account::session::account::SessionAccount; | ||
| use slot::account_sdk::account::session::merkle::MerkleTree; | ||
| use slot::account_sdk::account::session::policy::{CallPolicy, MerkleLeaf, Policy, ProvedPolicy}; | ||
|
|
@@ -17,6 +22,63 @@ use url::Url; | |
| #[allow(missing_debug_implementations)] | ||
| pub type ControllerAccount = SessionAccount; | ||
|
|
||
| const CONTROLLER_OAUTH_TIMEOUT_SECS: u64 = 300; | ||
| const CONTROLLER_OAUTH_CALLBACK_PATH: &str = "/callback"; | ||
| const CONTROLLER_LOGIN_PATH: &str = "/slot"; | ||
| const CONTROLLER_ACCOUNT_INFO_QUERY: &str = r#" | ||
| query ControllerAccountInfo { | ||
| me { | ||
| id | ||
| username | ||
| controllers { | ||
| edges { | ||
| node { | ||
| id | ||
| address | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| "#; | ||
|
|
||
| #[derive(Debug, Deserialize)] | ||
| struct ControllerAccountInfoResponse { | ||
| me: Option<ControllerAccountInfo>, | ||
| } | ||
|
|
||
| #[derive(Debug, Deserialize)] | ||
| struct ControllerAccountInfo { | ||
| id: String, | ||
| username: String, | ||
| controllers: ControllerEdges, | ||
| } | ||
|
|
||
| #[derive(Debug, Deserialize)] | ||
| struct ControllerEdges { | ||
| edges: Option<Vec<Option<ControllerEdge>>>, | ||
| } | ||
|
|
||
| #[derive(Debug, Deserialize)] | ||
| struct ControllerEdge { | ||
| node: Option<ControllerNode>, | ||
| } | ||
|
|
||
| #[derive(Debug, Deserialize)] | ||
| struct ControllerNode { | ||
| id: String, | ||
| address: String, | ||
| } | ||
|
|
||
| #[derive(Debug, Serialize)] | ||
| struct GraphqlRequest<'a, T> | ||
| where | ||
| T: Serialize, | ||
| { | ||
| query: &'a str, | ||
| variables: T, | ||
| } | ||
|
|
||
| /// Create a new Catridge Controller account based on session key. | ||
| /// | ||
| /// For now, Controller guarantees that if the provided network is among one of the supported | ||
|
|
@@ -38,7 +100,7 @@ pub async fn create_controller( | |
| let chain_id = rpc_provider.chain_id().await?; | ||
|
|
||
| trace!(target: "account::controller", "Loading Slot credentials."); | ||
| let credentials = slot::credential::Credentials::load()?; | ||
| let credentials = load_or_bootstrap_credentials().await?; | ||
| let username = credentials.account.id; | ||
|
|
||
| // Right now, the Cartridge Controller API ensures that there's always a Controller associated | ||
|
|
@@ -85,6 +147,169 @@ pub async fn create_controller( | |
| Ok(session_details.into_account(rpc_provider)) | ||
| } | ||
|
|
||
| async fn load_or_bootstrap_credentials() -> Result<slot::credential::Credentials> { | ||
| match slot::credential::Credentials::load() { | ||
| Ok(credentials) => Ok(credentials), | ||
| Err(err) if should_bootstrap_credentials(&err) => { | ||
| trace!( | ||
| target: "account::controller", | ||
| error = %err, | ||
| "No valid controller credentials found. Starting inline authorization flow." | ||
| ); | ||
| bootstrap_credentials().await?; | ||
| slot::credential::Credentials::load() | ||
| .context("Controller credentials were created but could not be loaded") | ||
| .map_err(Into::into) | ||
| } | ||
| Err(err) => Err(err.into()), | ||
| } | ||
| } | ||
|
|
||
| fn should_bootstrap_credentials(err: &slot::Error) -> bool { | ||
| matches!( | ||
| err, | ||
| slot::Error::Unauthorized | slot::Error::MalformedCredentials | slot::Error::InvalidOAuth | ||
| ) | ||
| } | ||
|
|
||
| async fn bootstrap_credentials() -> Result<()> { | ||
| let listener = TcpListener::bind("127.0.0.1:0") | ||
| .context("Failed to start local callback listener for controller authorization")?; | ||
|
|
||
| let callback_uri = format!( | ||
| "http://127.0.0.1:{}{}", | ||
| listener.local_addr()?.port(), | ||
| CONTROLLER_OAUTH_CALLBACK_PATH | ||
| ); | ||
|
|
||
| let mut authorize_url = Url::parse(&slot::vars::get_cartridge_keychain_url()) | ||
| .context("Invalid Cartridge keychain URL")?; | ||
| authorize_url.set_path(CONTROLLER_LOGIN_PATH); | ||
| authorize_url.query_pairs_mut().append_pair("callback_uri", &callback_uri); | ||
|
|
||
| println!("Authorize your controller account in browser:\n\n {}\n", authorize_url); | ||
|
|
||
| slot::browser::open(authorize_url.as_str())?; | ||
|
|
||
| let code = tokio::time::timeout( | ||
| Duration::from_secs(CONTROLLER_OAUTH_TIMEOUT_SECS), | ||
| tokio::task::spawn_blocking(move || wait_for_oauth_code(listener)), | ||
| ) | ||
| .await | ||
| .map_err(|_| { | ||
| anyhow!( | ||
| "Timed out waiting for controller authorization callback after {} seconds.", | ||
| CONTROLLER_OAUTH_TIMEOUT_SECS | ||
| ) | ||
| })? | ||
| .map_err(|e| anyhow!("Failed to run controller authorization callback listener: {e}"))??; | ||
|
Comment on lines
+741
to
+771
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # First, let's find and examine the actual file
find . -name "controller.rs" -path "*/commands/options/account/*" | head -20Repository: dojoengine/dojo Length of output: 112 🏁 Script executed: # List the general structure to understand the codebase
git ls-files | grep -E "(controller|oauth)" | head -20Repository: dojoengine/dojo Length of output: 146 🏁 Script executed: # Examine the file content around lines 175-205
wc -l ./bin/sozo/src/commands/options/account/controller.rsRepository: dojoengine/dojo Length of output: 116 🏁 Script executed: # Read the relevant section of the file
sed -n '170,210p' ./bin/sozo/src/commands/options/account/controller.rs | cat -nRepository: dojoengine/dojo Length of output: 1850 🏁 Script executed: # Also check the imports and function signature
head -50 ./bin/sozo/src/commands/options/account/controller.rs | cat -nRepository: dojoengine/dojo Length of output: 1807 🏁 Script executed: # Let's find and examine the wait_for_oauth_code function
rg -A 20 "fn wait_for_oauth_code" ./bin/sozo/src/commands/options/account/controller.rs | head -40Repository: dojoengine/dojo Length of output: 892 🏁 Script executed: # Let's also check if there are any other uses of spawn_blocking in the file
rg "spawn_blocking" ./bin/sozo/src/commands/options/account/controller.rsRepository: dojoengine/dojo Length of output: 134 🏁 Script executed: # Check if JoinHandle is used elsewhere or if there are similar patterns
rg "tokio::time::timeout" ./bin/sozo/src/ --type rustRepository: dojoengine/dojo Length of output: 149 🏁 Script executed: # Let's verify the exact behavior by checking Tokio documentation behavior
# and understand the context better. First, check if the function returns after timeout
sed -n '175,220p' ./bin/sozo/src/commands/options/account/controller.rs | cat -nRepository: dojoengine/dojo Length of output: 2125 🏁 Script executed: # Check what happens after bootstrap_credentials function completes
rg -B 5 -A 15 "fn bootstrap_credentials" ./bin/sozo/src/commands/options/account/controller.rs | tail -50Repository: dojoengine/dojo Length of output: 879 🏁 Script executed: # Check if there are any cleanup/drop implementations or finalizers
rg "Drop|drop|abort|cancel" ./bin/sozo/src/commands/options/account/controller.rsRepository: dojoengine/dojo Length of output: 41 🏁 Script executed: # Let's verify the behavior of tokio::time::timeout to be absolutely certain
# Check if there are any tests or documentation about this pattern in the repo
rg -i "timeout.*spawn_blocking|spawn_blocking.*timeout" --type rust | head -20Repository: dojoengine/dojo Length of output: 41 🌐 Web query:
💡 Result:
So: a timeout around awaiting a If you need real cancellation, you must build it into the blocking operation (e.g., pass a cancellation flag/channel and have the blocking code check it), because Tokio cannot preemptively stop arbitrary blocking code [3]. Sources: [1] Tokio Citations:
ohayo sensei — Timeout doesn't abort the blocking listener task; the fix needs real cancellation. The timeout will drop the 🤖 Prompt for AI Agents |
||
|
|
||
| let mut api = slot::api::Client::new(); | ||
| let token = api.oauth2(&code).await.context("Failed to exchange OAuth code")?; | ||
| api.set_token(token.clone()); | ||
|
|
||
| let account_info = fetch_controller_account_info(&api) | ||
| .await | ||
| .context("Failed to load Controller account details after authorization")?; | ||
|
|
||
| let path = slot::credential::Credentials::new(account_info, token) | ||
| .store() | ||
| .context("Failed to store controller credentials")?; | ||
|
|
||
| trace!( | ||
| target: "account::controller", | ||
| path = %path.display(), | ||
| "Controller credentials stored." | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| async fn fetch_controller_account_info( | ||
| api: &slot::api::Client, | ||
| ) -> Result<slot::account::AccountInfo> { | ||
| let request = | ||
| GraphqlRequest { query: CONTROLLER_ACCOUNT_INFO_QUERY, variables: serde_json::json!({}) }; | ||
|
|
||
| let response: ControllerAccountInfoResponse = api.query(&request).await?; | ||
| let me = response.me.ok_or_else(|| anyhow!("Missing `me` account info in API response"))?; | ||
|
|
||
| let mut controllers = Vec::new(); | ||
| for edge in me.controllers.edges.unwrap_or_default().into_iter().flatten() { | ||
| let Some(node) = edge.node else { | ||
| continue; | ||
| }; | ||
|
|
||
| let address = Felt::from_str(&node.address) | ||
| .with_context(|| format!("Invalid controller address `{}`", node.address))?; | ||
|
|
||
| controllers.push(slot::account::Controller { id: node.id, address }); | ||
| } | ||
|
|
||
| Ok(slot::account::AccountInfo { | ||
| id: me.id, | ||
| username: me.username, | ||
| controllers, | ||
| credentials: Vec::new(), | ||
| }) | ||
| } | ||
|
|
||
| fn wait_for_oauth_code(listener: TcpListener) -> Result<String> { | ||
| let (mut stream, _) = | ||
| listener.accept().context("Failed to accept controller OAuth callback connection")?; | ||
|
|
||
| let mut buffer = [0_u8; 8192]; | ||
| let bytes_read = | ||
| stream.read(&mut buffer).context("Failed to read controller OAuth callback request")?; | ||
| if bytes_read == 0 { | ||
| bail!("Controller OAuth callback request was empty."); | ||
| } | ||
|
|
||
| let request = String::from_utf8_lossy(&buffer[..bytes_read]); | ||
| let request_line = request.lines().next().unwrap_or_default(); | ||
| let target = request_line | ||
| .split_whitespace() | ||
| .nth(1) | ||
| .ok_or_else(|| anyhow!("Invalid callback request line: `{request_line}`"))?; | ||
|
|
||
| let Some(code) = extract_oauth_code(target) else { | ||
| write_http_response( | ||
| &mut stream, | ||
| "400 Bad Request", | ||
| "Missing authorization code. You can close this tab and retry.", | ||
| )?; | ||
| bail!("Controller OAuth callback does not contain `code` query parameter."); | ||
| }; | ||
|
|
||
| write_http_response( | ||
| &mut stream, | ||
| "200 OK", | ||
| "Controller authorization received. You can close this tab and return to sozo.", | ||
| )?; | ||
|
|
||
| Ok(code) | ||
| } | ||
|
|
||
| fn extract_oauth_code(target: &str) -> Option<String> { | ||
| let callback_url = Url::parse(&format!("http://localhost{target}")).ok()?; | ||
| if callback_url.path() != CONTROLLER_OAUTH_CALLBACK_PATH { | ||
| return None; | ||
| } | ||
|
|
||
| callback_url.query_pairs().find_map(|(key, value)| (key == "code").then(|| value.into_owned())) | ||
| } | ||
|
|
||
| fn write_http_response(stream: &mut TcpStream, status: &str, body: &str) -> Result<()> { | ||
| let response = format!( | ||
| "HTTP/1.1 {status}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: \ | ||
| {}\r\nConnection: close\r\n\r\n{body}", | ||
| body.len() | ||
| ); | ||
| stream.write_all(response.as_bytes())?; | ||
| stream.flush()?; | ||
| Ok(()) | ||
| } | ||
|
|
||
| // Check if the new policies are equal to the ones in the existing session | ||
| // | ||
| // This function would compute the merkle root of the new policies and compare it with the root in | ||
|
|
@@ -171,7 +396,7 @@ mod tests { | |
| use scarb_metadata_ext::MetadataDojoExt; | ||
| use starknet::macros::felt; | ||
|
|
||
| use super::{collect_policies, PolicyMethod}; | ||
| use super::{PolicyMethod, collect_policies, extract_oauth_code}; | ||
|
|
||
| #[test] | ||
| fn collect_policies_from_project() { | ||
|
|
@@ -200,4 +425,21 @@ mod tests { | |
| }); | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn extract_oauth_code_from_callback_target() { | ||
| let code = extract_oauth_code("/callback?code=abc123&state=xyz"); | ||
| assert_eq!(code.as_deref(), Some("abc123")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn extract_oauth_code_decodes_url_encoded_value() { | ||
| let code = extract_oauth_code("/callback?code=abc%2F123"); | ||
| assert_eq!(code.as_deref(), Some("abc/123")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn extract_oauth_code_rejects_non_callback_target() { | ||
| assert_eq!(extract_oauth_code("/not-callback?code=abc123"), None); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,7 +2,7 @@ use std::collections::HashMap; | |||||
| use std::str::FromStr; | ||||||
| use std::sync::Arc; | ||||||
|
|
||||||
| use anyhow::{anyhow, Result}; | ||||||
| use anyhow::{Result, anyhow}; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ohayo sensei — Fix rustfmt import ordering to unblock CI. Suggested fix-use anyhow::{Result, anyhow};
+use anyhow::{anyhow, Result};📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
| use clap::{Args, ValueEnum}; | ||||||
| use dojo_utils::env::DOJO_ACCOUNT_ADDRESS_ENV_VAR; | ||||||
| use dojo_world::config::Environment; | ||||||
|
|
@@ -52,6 +52,13 @@ pub struct AccountOptions { | |||||
| #[cfg(feature = "controller")] | ||||||
| pub controller: bool, | ||||||
|
|
||||||
| #[arg(global = true)] | ||||||
| #[arg(long = "session")] | ||||||
| #[arg(help_heading = "Controller options")] | ||||||
| #[arg(help = "Use Cartridge Controller session account (alias of --slot.controller)")] | ||||||
| #[cfg(feature = "controller")] | ||||||
| pub session: bool, | ||||||
|
|
||||||
| #[command(flatten)] | ||||||
| #[command(next_help_heading = "Signer options")] | ||||||
| pub signer: SignerOptions, | ||||||
|
|
@@ -85,7 +92,7 @@ impl AccountOptions { | |||||
| P: Send + Sync, | ||||||
| { | ||||||
| #[cfg(feature = "controller")] | ||||||
| if self.controller { | ||||||
| if self.controller || self.session { | ||||||
| let url = starknet.url(env_metadata)?; | ||||||
| let cartridge_provider = CartridgeJsonRpcProvider::new(url.clone()); | ||||||
| let account = self.controller(url, cartridge_provider.clone(), contracts).await?; | ||||||
|
|
@@ -222,6 +229,14 @@ mod tests { | |||||
| ); | ||||||
| } | ||||||
|
|
||||||
| #[cfg(feature = "controller")] | ||||||
| #[test] | ||||||
| fn controller_session_alias_flag_is_parsed() { | ||||||
| let cmd = Command::parse_from(["sozo", "--session"]); | ||||||
| assert!(cmd.account.session); | ||||||
| assert!(!cmd.account.controller); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn account_address_from_both() { | ||||||
| let env_metadata = dojo_world::config::Environment { | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohayo sensei — Rustfmt import ordering is failing CI here.
Please run rustfmt or reorder the anyhow imports.
🤖 Prompt for AI Agents