Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/tests/simulation/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::time::Duration;

use indexmap::IndexMap;
use itertools::Itertools;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{RoleContextModifier, SetRoleSpec, Statement};
use risingwave_sqlparser::parser::Parser;
use shell_words::split;
use sqllogictest::{DBOutput, DefaultColumnType};
Expand Down Expand Up @@ -60,7 +60,24 @@ impl SetStmts {
// store complete sql as value.
self.stmts.insert(key, sql.to_owned());
}
_ => unreachable!(),
Statement::SetTimeZone { .. } => {
self.stmts.insert("time zone".to_owned(), sql.to_owned());
}
Statement::SetRole {
context_modifier: Some(RoleContextModifier::Local),
..
} => {}
Statement::SetRole { role_name, .. } => {
if matches!(role_name, SetRoleSpec::None) {
self.stmts.shift_remove("role");
} else {
self.stmts.insert("role".to_owned(), sql.to_owned());
}
}
Statement::ResetRole => {
self.stmts.shift_remove("role");
}
_ => {}
}
}

Expand Down Expand Up @@ -183,7 +200,8 @@ impl sqllogictest::AsyncDB for RisingWave {
self.reconnect().await?;
}

if sql.trim_start().to_lowercase().starts_with("set") {
let normalized_sql = sql.trim_start().to_lowercase();
if normalized_sql.starts_with("set") || normalized_sql.starts_with("reset role") {
self.set_stmts.push(sql);
}

Expand Down Expand Up @@ -261,4 +279,31 @@ mod tests {
assert_eq!(output.status.code(), Some(1));
assert_eq!(output.stderr, b"ctl failed\n");
}

#[test]
fn set_stmts_replays_session_role_until_reset() {
let mut set_stmts = SetStmts::default();

set_stmts.push("SET ROLE rw_visible_user");
assert_eq!(
set_stmts.replay_iter().collect::<Vec<_>>(),
vec!["SET ROLE rw_visible_user"]
);

set_stmts.push("RESET ROLE");
assert!(set_stmts.replay_iter().collect::<Vec<_>>().is_empty());

set_stmts.push("SET ROLE rw_visible_user");
set_stmts.push("SET ROLE NONE");
assert!(set_stmts.replay_iter().collect::<Vec<_>>().is_empty());
}

#[test]
fn set_stmts_does_not_replay_local_role() {
let mut set_stmts = SetStmts::default();

set_stmts.push("SET LOCAL ROLE rw_visible_user");

assert!(set_stmts.replay_iter().collect::<Vec<_>>().is_empty());
}
}
Loading