diff --git a/src/tests/simulation/src/client.rs b/src/tests/simulation/src/client.rs index 5f53aaec708c9..4b270627cf7e2 100644 --- a/src/tests/simulation/src/client.rs +++ b/src/tests/simulation/src/client.rs @@ -16,7 +16,7 @@ use std::time::Duration; use indexmap::IndexMap; use itertools::Itertools; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{RoleContextModifier, SetRoleSpec, Statement}; use risingwave_sqlparser::parser::Parser; use shell_words::split; use sqllogictest::{DBOutput, DefaultColumnType}; @@ -60,7 +60,24 @@ impl SetStmts { // store complete sql as value. self.stmts.insert(key, sql.to_owned()); } - _ => unreachable!(), + Statement::SetTimeZone { .. } => { + self.stmts.insert("time zone".to_owned(), sql.to_owned()); + } + Statement::SetRole { + context_modifier: Some(RoleContextModifier::Local), + .. + } => {} + Statement::SetRole { role_name, .. } => { + if matches!(role_name, SetRoleSpec::None) { + self.stmts.shift_remove("role"); + } else { + self.stmts.insert("role".to_owned(), sql.to_owned()); + } + } + Statement::ResetRole => { + self.stmts.shift_remove("role"); + } + _ => {} } } @@ -183,7 +200,8 @@ impl sqllogictest::AsyncDB for RisingWave { self.reconnect().await?; } - if sql.trim_start().to_lowercase().starts_with("set") { + let normalized_sql = sql.trim_start().to_lowercase(); + if normalized_sql.starts_with("set") || normalized_sql.starts_with("reset role") { self.set_stmts.push(sql); } @@ -261,4 +279,31 @@ mod tests { assert_eq!(output.status.code(), Some(1)); assert_eq!(output.stderr, b"ctl failed\n"); } + + #[test] + fn set_stmts_replays_session_role_until_reset() { + let mut set_stmts = SetStmts::default(); + + set_stmts.push("SET ROLE rw_visible_user"); + assert_eq!( + set_stmts.replay_iter().collect::>(), + vec!["SET ROLE rw_visible_user"] + ); + + set_stmts.push("RESET ROLE"); + assert!(set_stmts.replay_iter().collect::>().is_empty()); + + set_stmts.push("SET ROLE rw_visible_user"); + set_stmts.push("SET ROLE NONE"); + assert!(set_stmts.replay_iter().collect::>().is_empty()); + } + + #[test] + fn set_stmts_does_not_replay_local_role() { + let mut set_stmts = SetStmts::default(); + + set_stmts.push("SET LOCAL ROLE rw_visible_user"); + + assert!(set_stmts.replay_iter().collect::>().is_empty()); + } }