Skip to content
11 changes: 11 additions & 0 deletions bin/sozo/src/commands/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub(crate) mod mcp;
pub(crate) mod migrate;
pub(crate) mod model;
pub(crate) mod options;
#[cfg(feature = "controller")]
pub(crate) mod session;
pub(crate) mod starknet;
pub(crate) mod test;
pub(crate) mod version;
Expand All @@ -45,6 +47,8 @@ use invoke::InvokeArgs;
use mcp::McpArgs;
use migrate::MigrateArgs;
use model::ModelArgs;
#[cfg(feature = "controller")]
use session::SessionArgs;
#[cfg(feature = "walnut")]
use sozo_walnut::walnut::WalnutArgs;
use starknet::StarknetArgs;
Expand Down Expand Up @@ -86,6 +90,9 @@ pub enum Commands {
Migrate(Box<MigrateArgs>),
#[command(about = "Inspect a model")]
Model(Box<ModelArgs>),
#[cfg(feature = "controller")]
#[command(about = "Manage Cartridge controller sessions")]
Session(Box<SessionArgs>),
#[command(about = "Runs cairo tests")]
Test(Box<TestArgs>),
#[command(about = "Print version")]
Expand Down Expand Up @@ -118,6 +125,8 @@ impl fmt::Display for Commands {
Commands::Inspect(_) => write!(f, "Inspect"),
Commands::Migrate(_) => write!(f, "Migrate"),
Commands::Model(_) => write!(f, "Model"),
#[cfg(feature = "controller")]
Commands::Session(_) => write!(f, "Session"),
Commands::Test(_) => write!(f, "Test"),
Commands::Version(_) => write!(f, "Version"),
Commands::Mcp(_) => write!(f, "Mcp"),
Expand Down Expand Up @@ -150,6 +159,8 @@ pub async fn run(command: Commands, scarb_metadata: &Metadata, ui: &SozoUi) -> R
Commands::Mcp(args) => args.run(scarb_metadata).await,
Commands::Migrate(args) => args.run(scarb_metadata, ui).await,
Commands::Model(args) => args.run(scarb_metadata, ui).await,
#[cfg(feature = "controller")]
Commands::Session(args) => args.run(scarb_metadata, ui).await,
Commands::Test(args) => args.run(scarb_metadata),
Commands::Version(args) => args.run(scarb_metadata),
#[cfg(feature = "walnut")]
Expand Down
248 changes: 245 additions & 3 deletions bin/sozo/src/commands/options/account/controller.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::str::FromStr;
use std::time::Duration;

use anyhow::{bail, Result};
use anyhow::{Context, Result, anyhow, bail};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

ohayo sensei — Rustfmt import ordering is failing CI here.
Please run rustfmt or reorder the anyhow imports.

🤖 Prompt for AI Agents
In `@bin/sozo/src/commands/options/account/controller.rs` at line 7, The anyhow
import list is out-of-order and failing rustfmt; reorder the import or run
rustfmt so the items are sorted (for example change use anyhow::{Context,
Result, anyhow, bail}; to use anyhow::{anyhow, bail, Context, Result};) or
simply run cargo fmt to fix the import ordering in controller.rs.

use dojo_world::contracts::contract_info::ContractInfo;
use serde::{Deserialize, Serialize};
use slot::account_sdk::account::session::account::SessionAccount;
use slot::account_sdk::account::session::merkle::MerkleTree;
use slot::account_sdk::account::session::policy::{CallPolicy, MerkleLeaf, Policy, ProvedPolicy};
Expand All @@ -17,6 +22,63 @@ use url::Url;
#[allow(missing_debug_implementations)]
pub type ControllerAccount = SessionAccount;

const CONTROLLER_OAUTH_TIMEOUT_SECS: u64 = 300;
const CONTROLLER_OAUTH_CALLBACK_PATH: &str = "/callback";
const CONTROLLER_LOGIN_PATH: &str = "/slot";
const CONTROLLER_ACCOUNT_INFO_QUERY: &str = r#"
query ControllerAccountInfo {
me {
id
username
controllers {
edges {
node {
id
address
}
}
}
}
}
"#;

#[derive(Debug, Deserialize)]
struct ControllerAccountInfoResponse {
me: Option<ControllerAccountInfo>,
}

#[derive(Debug, Deserialize)]
struct ControllerAccountInfo {
id: String,
username: String,
controllers: ControllerEdges,
}

#[derive(Debug, Deserialize)]
struct ControllerEdges {
edges: Option<Vec<Option<ControllerEdge>>>,
}

#[derive(Debug, Deserialize)]
struct ControllerEdge {
node: Option<ControllerNode>,
}

#[derive(Debug, Deserialize)]
struct ControllerNode {
id: String,
address: String,
}

#[derive(Debug, Serialize)]
struct GraphqlRequest<'a, T>
where
T: Serialize,
{
query: &'a str,
variables: T,
}

/// Create a new Catridge Controller account based on session key.
///
/// For now, Controller guarantees that if the provided network is among one of the supported
Expand All @@ -38,7 +100,7 @@ pub async fn create_controller(
let chain_id = rpc_provider.chain_id().await?;

trace!(target: "account::controller", "Loading Slot credentials.");
let credentials = slot::credential::Credentials::load()?;
let credentials = load_or_bootstrap_credentials().await?;
let username = credentials.account.id;

// Right now, the Cartridge Controller API ensures that there's always a Controller associated
Expand Down Expand Up @@ -85,6 +147,169 @@ pub async fn create_controller(
Ok(session_details.into_account(rpc_provider))
}

async fn load_or_bootstrap_credentials() -> Result<slot::credential::Credentials> {
match slot::credential::Credentials::load() {
Ok(credentials) => Ok(credentials),
Err(err) if should_bootstrap_credentials(&err) => {
trace!(
target: "account::controller",
error = %err,
"No valid controller credentials found. Starting inline authorization flow."
);
bootstrap_credentials().await?;
slot::credential::Credentials::load()
.context("Controller credentials were created but could not be loaded")
.map_err(Into::into)
}
Err(err) => Err(err.into()),
}
}

fn should_bootstrap_credentials(err: &slot::Error) -> bool {
matches!(
err,
slot::Error::Unauthorized | slot::Error::MalformedCredentials | slot::Error::InvalidOAuth
)
}

async fn bootstrap_credentials() -> Result<()> {
let listener = TcpListener::bind("127.0.0.1:0")
.context("Failed to start local callback listener for controller authorization")?;

let callback_uri = format!(
"http://127.0.0.1:{}{}",
listener.local_addr()?.port(),
CONTROLLER_OAUTH_CALLBACK_PATH
);

let mut authorize_url = Url::parse(&slot::vars::get_cartridge_keychain_url())
.context("Invalid Cartridge keychain URL")?;
authorize_url.set_path(CONTROLLER_LOGIN_PATH);
authorize_url.query_pairs_mut().append_pair("callback_uri", &callback_uri);

println!("Authorize your controller account in browser:\n\n {}\n", authorize_url);

slot::browser::open(authorize_url.as_str())?;

let code = tokio::time::timeout(
Duration::from_secs(CONTROLLER_OAUTH_TIMEOUT_SECS),
tokio::task::spawn_blocking(move || wait_for_oauth_code(listener)),
)
.await
.map_err(|_| {
anyhow!(
"Timed out waiting for controller authorization callback after {} seconds.",
CONTROLLER_OAUTH_TIMEOUT_SECS
)
})?
.map_err(|e| anyhow!("Failed to run controller authorization callback listener: {e}"))??;
Comment on lines +741 to +771
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's find and examine the actual file
find . -name "controller.rs" -path "*/commands/options/account/*" | head -20

Repository: dojoengine/dojo

Length of output: 112


🏁 Script executed:

# List the general structure to understand the codebase
git ls-files | grep -E "(controller|oauth)" | head -20

Repository: dojoengine/dojo

Length of output: 146


🏁 Script executed:

# Examine the file content around lines 175-205
wc -l ./bin/sozo/src/commands/options/account/controller.rs

Repository: dojoengine/dojo

Length of output: 116


🏁 Script executed:

# Read the relevant section of the file
sed -n '170,210p' ./bin/sozo/src/commands/options/account/controller.rs | cat -n

Repository: dojoengine/dojo

Length of output: 1850


🏁 Script executed:

# Also check the imports and function signature
head -50 ./bin/sozo/src/commands/options/account/controller.rs | cat -n

Repository: dojoengine/dojo

Length of output: 1807


🏁 Script executed:

# Let's find and examine the wait_for_oauth_code function
rg -A 20 "fn wait_for_oauth_code" ./bin/sozo/src/commands/options/account/controller.rs | head -40

Repository: dojoengine/dojo

Length of output: 892


🏁 Script executed:

# Let's also check if there are any other uses of spawn_blocking in the file
rg "spawn_blocking" ./bin/sozo/src/commands/options/account/controller.rs

Repository: dojoengine/dojo

Length of output: 134


🏁 Script executed:

# Check if JoinHandle is used elsewhere or if there are similar patterns
rg "tokio::time::timeout" ./bin/sozo/src/ --type rust

Repository: dojoengine/dojo

Length of output: 149


🏁 Script executed:

# Let's verify the exact behavior by checking Tokio documentation behavior
# and understand the context better. First, check if the function returns after timeout
sed -n '175,220p' ./bin/sozo/src/commands/options/account/controller.rs | cat -n

Repository: dojoengine/dojo

Length of output: 2125


🏁 Script executed:

# Check what happens after bootstrap_credentials function completes
rg -B 5 -A 15 "fn bootstrap_credentials" ./bin/sozo/src/commands/options/account/controller.rs | tail -50

Repository: dojoengine/dojo

Length of output: 879


🏁 Script executed:

# Check if there are any cleanup/drop implementations or finalizers
rg "Drop|drop|abort|cancel" ./bin/sozo/src/commands/options/account/controller.rs

Repository: dojoengine/dojo

Length of output: 41


🏁 Script executed:

# Let's verify the behavior of tokio::time::timeout to be absolutely certain
# Check if there are any tests or documentation about this pattern in the repo
rg -i "timeout.*spawn_blocking|spawn_blocking.*timeout" --type rust | head -20

Repository: dojoengine/dojo

Length of output: 41


🌐 Web query:

tokio::time::timeout behavior with spawn_blocking task cancellation

💡 Result:

tokio::time::timeout(dur, fut) “cancels” by dropping the wrapped future when the deadline hits (it returns Err(Elapsed) and stops polling the inner future) [1]. That has important consequences with spawn_blocking:

  • If you do timeout(dur, handle).await (where handle: JoinHandle<T>), the inner future is effectively handle.await. When the timeout elapses, that await is dropped, which drops the JoinHandle.
  • Dropping a JoinHandle detaches the task; it does not stop it [2]. So the spawn_blocking closure will typically keep running in the blocking thread pool after the timeout.
  • Even if you try to “cancel” it explicitly, JoinHandle::abort() / AbortHandle::abort() does not abort spawn_blocking once it has started running (it may only prevent it from starting if it hasn’t begun yet) [2][3].

So: a timeout around awaiting a spawn_blocking task only times out the wait; it does not cancel the blocking work.

If you need real cancellation, you must build it into the blocking operation (e.g., pass a cancellation flag/channel and have the blocking code check it), because Tokio cannot preemptively stop arbitrary blocking code [3].

Sources: [1] Tokio timeout docs (cancellation by dropping) (docs.rs), [2] Tokio JoinHandle docs (drop detaches; abort doesn’t affect spawn_blocking) (docs.rs), [3] Tokio spawn_blocking docs (cannot be aborted once started; shutdown behavior) (docs.rs)

Citations:


ohayo sensei — Timeout doesn't abort the blocking listener task; the fix needs real cancellation.

The timeout will drop the JoinHandle, but spawn_blocking tasks cannot be preemptively aborted once they start running. The listener.accept() call will continue blocking indefinitely, keeping the listener alive. Using abort() alone won't stop the blocking operation—you need to pass a cancellation signal (e.g., a flag or channel) into the blocking closure so wait_for_oauth_code can check it and exit gracefully.

🤖 Prompt for AI Agents
In `@bin/sozo/src/commands/options/account/controller.rs` around lines 175 - 205,
The timeout currently just drops the JoinHandle while the spawn_blocking closure
(and wait_for_oauth_code) may remain blocked on listener.accept(); change this
to cooperative cancellation by adding a cancellation signal: update
wait_for_oauth_code to accept a cancellation receiver (e.g., tokio::sync::watch
or oneshot) and make it poll that signal while using non-blocking/timeout accept
or looping with short sleeps so it can exit when cancelled; in
bootstrap_credentials create the cancellation sender/receiver, pass the receiver
into the spawn_blocking closure that calls wait_for_oauth_code(listener,
cancel_rx), and when the tokio::time::timeout fires send the cancel signal (and
optionally abort the JoinHandle) so the blocking listener is closed and the task
can return promptly; ensure listener is closed/cleaned up on cancellation.


let mut api = slot::api::Client::new();
let token = api.oauth2(&code).await.context("Failed to exchange OAuth code")?;
api.set_token(token.clone());

let account_info = fetch_controller_account_info(&api)
.await
.context("Failed to load Controller account details after authorization")?;

let path = slot::credential::Credentials::new(account_info, token)
.store()
.context("Failed to store controller credentials")?;

trace!(
target: "account::controller",
path = %path.display(),
"Controller credentials stored."
);

Ok(())
}

async fn fetch_controller_account_info(
api: &slot::api::Client,
) -> Result<slot::account::AccountInfo> {
let request =
GraphqlRequest { query: CONTROLLER_ACCOUNT_INFO_QUERY, variables: serde_json::json!({}) };

let response: ControllerAccountInfoResponse = api.query(&request).await?;
let me = response.me.ok_or_else(|| anyhow!("Missing `me` account info in API response"))?;

let mut controllers = Vec::new();
for edge in me.controllers.edges.unwrap_or_default().into_iter().flatten() {
let Some(node) = edge.node else {
continue;
};

let address = Felt::from_str(&node.address)
.with_context(|| format!("Invalid controller address `{}`", node.address))?;

controllers.push(slot::account::Controller { id: node.id, address });
}

Ok(slot::account::AccountInfo {
id: me.id,
username: me.username,
controllers,
credentials: Vec::new(),
})
}

fn wait_for_oauth_code(listener: TcpListener) -> Result<String> {
let (mut stream, _) =
listener.accept().context("Failed to accept controller OAuth callback connection")?;

let mut buffer = [0_u8; 8192];
let bytes_read =
stream.read(&mut buffer).context("Failed to read controller OAuth callback request")?;
if bytes_read == 0 {
bail!("Controller OAuth callback request was empty.");
}

let request = String::from_utf8_lossy(&buffer[..bytes_read]);
let request_line = request.lines().next().unwrap_or_default();
let target = request_line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow!("Invalid callback request line: `{request_line}`"))?;

let Some(code) = extract_oauth_code(target) else {
write_http_response(
&mut stream,
"400 Bad Request",
"Missing authorization code. You can close this tab and retry.",
)?;
bail!("Controller OAuth callback does not contain `code` query parameter.");
};

write_http_response(
&mut stream,
"200 OK",
"Controller authorization received. You can close this tab and return to sozo.",
)?;

Ok(code)
}

fn extract_oauth_code(target: &str) -> Option<String> {
let callback_url = Url::parse(&format!("http://localhost{target}")).ok()?;
if callback_url.path() != CONTROLLER_OAUTH_CALLBACK_PATH {
return None;
}

callback_url.query_pairs().find_map(|(key, value)| (key == "code").then(|| value.into_owned()))
}

fn write_http_response(stream: &mut TcpStream, status: &str, body: &str) -> Result<()> {
let response = format!(
"HTTP/1.1 {status}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: \
{}\r\nConnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes())?;
stream.flush()?;
Ok(())
}

// Check if the new policies are equal to the ones in the existing session
//
// This function would compute the merkle root of the new policies and compare it with the root in
Expand Down Expand Up @@ -171,7 +396,7 @@ mod tests {
use scarb_metadata_ext::MetadataDojoExt;
use starknet::macros::felt;

use super::{collect_policies, PolicyMethod};
use super::{PolicyMethod, collect_policies, extract_oauth_code};

#[test]
fn collect_policies_from_project() {
Expand Down Expand Up @@ -200,4 +425,21 @@ mod tests {
});
}
}

#[test]
fn extract_oauth_code_from_callback_target() {
let code = extract_oauth_code("/callback?code=abc123&state=xyz");
assert_eq!(code.as_deref(), Some("abc123"));
}

#[test]
fn extract_oauth_code_decodes_url_encoded_value() {
let code = extract_oauth_code("/callback?code=abc%2F123");
assert_eq!(code.as_deref(), Some("abc/123"));
}

#[test]
fn extract_oauth_code_rejects_non_callback_target() {
assert_eq!(extract_oauth_code("/not-callback?code=abc123"), None);
}
}
19 changes: 17 additions & 2 deletions bin/sozo/src/commands/options/account/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

ohayo sensei — Fix rustfmt import ordering to unblock CI.
Rustfmt is currently failing on this import ordering.

Suggested fix
-use anyhow::{Result, anyhow};
+use anyhow::{anyhow, Result};
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
🤖 Prompt for AI Agents
In `@bin/sozo/src/commands/options/account/mod.rs` at line 5, The import items in
the line "use anyhow::{Result, anyhow};" are out of rustfmt's required order;
change the nested import to alphabetic order so it reads "use anyhow::{anyhow,
Result};" to satisfy rustfmt and unblock CI.

use clap::{Args, ValueEnum};
use dojo_utils::env::DOJO_ACCOUNT_ADDRESS_ENV_VAR;
use dojo_world::config::Environment;
Expand Down Expand Up @@ -52,6 +52,13 @@ pub struct AccountOptions {
#[cfg(feature = "controller")]
pub controller: bool,

#[arg(global = true)]
#[arg(long = "session")]
#[arg(help_heading = "Controller options")]
#[arg(help = "Use Cartridge Controller session account (alias of --slot.controller)")]
#[cfg(feature = "controller")]
pub session: bool,

#[command(flatten)]
#[command(next_help_heading = "Signer options")]
pub signer: SignerOptions,
Expand Down Expand Up @@ -85,7 +92,7 @@ impl AccountOptions {
P: Send + Sync,
{
#[cfg(feature = "controller")]
if self.controller {
if self.controller || self.session {
let url = starknet.url(env_metadata)?;
let cartridge_provider = CartridgeJsonRpcProvider::new(url.clone());
let account = self.controller(url, cartridge_provider.clone(), contracts).await?;
Expand Down Expand Up @@ -222,6 +229,14 @@ mod tests {
);
}

#[cfg(feature = "controller")]
#[test]
fn controller_session_alias_flag_is_parsed() {
let cmd = Command::parse_from(["sozo", "--session"]);
assert!(cmd.account.session);
assert!(!cmd.account.controller);
}

#[test]
fn account_address_from_both() {
let env_metadata = dojo_world::config::Environment {
Expand Down
Loading
Loading