diff --git a/Cargo.lock b/Cargo.lock index 872095752b9e..6081085ba073 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2139,7 +2139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" dependencies = [ "lazy_static", - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] @@ -2632,10 +2632,13 @@ dependencies = [ "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-proto", "datatypes", "futures-util", "once_cell", + "prost 0.14.1", "serde", + "serde_json", "snafu 0.8.6", "sqlparser", "sqlparser_derive 0.1.1", @@ -4129,6 +4132,43 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-proto" +version = "52.1.0" +source = "git+https://github.com/GreptimeTeam/datafusion.git?rev=02b82535e0160c4545667f36a03e1ff9d1d2e51f#02b82535e0160c4545667f36a03e1ff9d1d2e51f" +dependencies = [ + "arrow 57.3.0", + "chrono", + "datafusion-catalog", + "datafusion-catalog-listing", + "datafusion-common", + "datafusion-datasource", + "datafusion-datasource-arrow", + "datafusion-datasource-csv", + "datafusion-datasource-json", + "datafusion-datasource-parquet", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions-table", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "datafusion-proto-common", + "object_store", + "prost 0.14.1", + "rand 0.9.1", +] + +[[package]] +name = "datafusion-proto-common" +version = "52.1.0" +source = "git+https://github.com/GreptimeTeam/datafusion.git?rev=02b82535e0160c4545667f36a03e1ff9d1d2e51f#02b82535e0160c4545667f36a03e1ff9d1d2e51f" +dependencies = [ + "arrow 57.3.0", + "datafusion-common", + "prost 0.14.1", +] + [[package]] name = "datafusion-pruning" version = "52.1.0" @@ -5681,7 +5721,7 @@ dependencies = [ [[package]] name = "greptime-proto" version = "0.1.0" -source = 
"git+https://github.com/GreptimeTeam/greptime-proto.git?rev=092ba1d01e2da676dca66cca7eebb55009da8ef8#092ba1d01e2da676dca66cca7eebb55009da8ef8" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=9423b6ae25e8e64b1c57ef2594a6a7698efb3c5a#9423b6ae25e8e64b1c57ef2594a6a7698efb3c5a" dependencies = [ "prost 0.14.1", "prost-types 0.14.1", @@ -6201,7 +6241,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2 0.5.10", "tokio", "tower-service", "tracing", @@ -7290,7 +7330,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -10321,7 +10361,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" dependencies = [ "heck 0.5.0", - "itertools 0.14.0", + "itertools 0.10.5", "log", "multimap", "once_cell", @@ -10369,7 +10409,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.14.0", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.117", @@ -10382,7 +10422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" dependencies = [ "anyhow", - "itertools 0.14.0", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.117", @@ -12137,6 +12177,7 @@ dependencies = [ "derive_more", "snafu 0.8.6", "sql", + "uuid", ] [[package]] @@ -15039,7 +15080,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml 
b/Cargo.toml index 227608bf64a8..5b5033455d0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -139,6 +139,7 @@ datafusion-orc = "0.7" datafusion-pg-catalog = "0.15.1" datafusion-physical-expr = "=52.1" datafusion-physical-plan = "=52.1" +datafusion-proto = "=52.1" datafusion-sql = "=52.1" datafusion-substrait = "=52.1" deadpool = "0.12" @@ -154,7 +155,7 @@ etcd-client = { version = "0.17", features = [ fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "092ba1d01e2da676dca66cca7eebb55009da8ef8" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "9423b6ae25e8e64b1c57ef2594a6a7698efb3c5a" } hex = "0.4" http = "1" humantime = "2.1" @@ -251,7 +252,7 @@ tracing-appender = "0.2" tracing-opentelemetry = "0.31.0" tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "fmt"] } typetag = "0.2" -uuid = { version = "1.17", features = ["serde", "v4", "fast-rng"] } +uuid = { version = "1.17", features = ["serde", "v4", "v7", "fast-rng"] } vrl = "0.25" zstd = "0.13" # DO_NOT_REMOVE_THIS: END_OF_EXTERNAL_DEPENDENCIES @@ -341,6 +342,7 @@ datafusion-optimizer = { git = "https://github.com/GreptimeTeam/datafusion.git", datafusion-physical-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-physical-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } +datafusion-proto = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-datasource = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-sql = { git = 
"https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } diff --git a/src/client/src/region.rs b/src/client/src/region.rs index 8eefb16e0d13..72321b265983 100644 --- a/src/client/src/region.rs +++ b/src/client/src/region.rs @@ -16,7 +16,10 @@ use std::sync::Arc; use api::region::RegionResponse; use api::v1::ResponseHeader; -use api::v1::region::RegionRequest; +use api::v1::region::{ + RegionRequest, RegionRequestHeader, RemoteDynFilterRequest, RemoteDynFilterUnregister, + RemoteDynFilterUpdate, region_request, remote_dyn_filter_request, +}; use arc_swap::ArcSwapOption; use arrow_flight::Ticket; use async_stream::stream; @@ -284,6 +287,48 @@ impl RegionRequester { pub async fn handle(&self, request: RegionRequest) -> Result { self.handle_inner(request).await } + + pub async fn handle_remote_dyn_filter_update( + &self, + query_id: impl Into, + update: RemoteDynFilterUpdate, + ) -> Result { + self.handle_inner(build_remote_dyn_filter_request( + query_id.into(), + remote_dyn_filter_request::Action::Update(update), + )) + .await + } + + pub async fn handle_remote_dyn_filter_unregister( + &self, + query_id: impl Into, + unregister: RemoteDynFilterUnregister, + ) -> Result { + self.handle_inner(build_remote_dyn_filter_request( + query_id.into(), + remote_dyn_filter_request::Action::Unregister(unregister), + )) + .await + } +} + +fn build_remote_dyn_filter_request( + query_id: String, + action: remote_dyn_filter_request::Action, +) -> RegionRequest { + RegionRequest { + header: Some(RegionRequestHeader { + tracing_context: TracingContext::from_current_span().to_w3c(), + ..Default::default() + }), + body: Some(region_request::Body::RemoteDynFilter( + RemoteDynFilterRequest { + query_id, + action: Some(action), + }, + )), + } } pub fn check_response_header(header: &Option) -> Result<()> { @@ -312,6 
+357,7 @@ pub fn check_response_header(header: &Option) -> Result<()> { #[cfg(test)] mod test { use api::v1::Status as PbStatus; + use api::v1::region::{RemoteDynFilterUpdate, region_request, remote_dyn_filter_request}; use super::*; use crate::Error::{IllegalDatabaseResponse, Server}; @@ -361,4 +407,30 @@ mod test { assert_eq!(code, StatusCode::Internal); assert_eq!(msg, "blabla"); } + + #[test] + fn test_build_remote_dyn_filter_request_sets_header_and_body() { + let request = build_remote_dyn_filter_request( + "query-1".to_string(), + remote_dyn_filter_request::Action::Update(RemoteDynFilterUpdate { + filter_id: "filter-1".to_string(), + payload: vec![1, 2, 3], + generation: 7, + is_complete: false, + }), + ); + + request.header.expect("remote dyn filter header must exist"); + + let body = request.body.expect("remote dyn filter body must exist"); + let region_request::Body::RemoteDynFilter(remote_request) = body else { + panic!("expected remote dyn filter request body"); + }; + + assert_eq!(remote_request.query_id, "query-1"); + assert!(matches!( + remote_request.action, + Some(remote_dyn_filter_request::Action::Update(_)) + )); + } } diff --git a/src/common/query/Cargo.toml b/src/common/query/Cargo.toml index 48328ea6121f..13f3c174b65c 100644 --- a/src/common/query/Cargo.toml +++ b/src/common/query/Cargo.toml @@ -22,8 +22,10 @@ common-time.workspace = true datafusion.workspace = true datafusion-common.workspace = true datafusion-expr.workspace = true +datafusion-proto.workspace = true datatypes.workspace = true once_cell.workspace = true +prost.workspace = true serde.workspace = true snafu.workspace = true sqlparser.workspace = true @@ -33,4 +35,5 @@ store-api.workspace = true [dev-dependencies] common-base.workspace = true futures-util.workspace = true +serde_json.workspace = true tokio.workspace = true diff --git a/src/common/query/src/request.rs b/src/common/query/src/request.rs index 260a43e79df0..c33e209557d8 100644 --- a/src/common/query/src/request.rs +++ 
b/src/common/query/src/request.rs @@ -12,10 +12,162 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use api::v1::region::RegionRequestHeader; +use datafusion::arrow::datatypes::Schema; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_plan::PhysicalExpr; +use datafusion::physical_plan::joins::HashTableLookupExpr; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_expr::LogicalPlan; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::protobuf::PhysicalExprNode; +use prost::Message; +use serde::{Deserialize, Serialize}; use store_api::storage::RegionId; +pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[non_exhaustive] +#[serde(tag = "kind", content = "payload", rename_all = "snake_case")] +pub enum DynFilterPayload { + Datafusion(Vec), +} + +impl DynFilterPayload { + pub fn from_datafusion_expr( + expr: &Arc, + max_payload_bytes: usize, + ) -> DataFusionResult { + validate_supported_payload_expr(expr)?; + + let codec = DefaultPhysicalExtensionCodec {}; + let proto = serialize_physical_expr(expr, &codec)?; + let mut bytes = Vec::new(); + proto.encode(&mut bytes).map_err(|e| { + DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}")) + })?; + + validate_payload_size(bytes.len(), max_payload_bytes)?; + + Ok(Self::Datafusion(bytes)) + } + + pub fn decode_datafusion_expr( + &self, + task_ctx: &TaskContext, + input_schema: &Schema, + max_payload_bytes: usize, + ) -> DataFusionResult> { + let Self::Datafusion(bytes) = self; + 
validate_payload_size(bytes.len(), max_payload_bytes)?; + let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| { + DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}")) + })?; + + let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?; + validate_supported_payload_expr(&expr)?; + validate_decoded_payload_expr(&expr, input_schema)?; + Ok(expr) + } +} + +fn validate_payload_size( + payload_size_bytes: usize, + max_payload_bytes: usize, +) -> DataFusionResult<()> { + if payload_size_bytes > max_payload_bytes { + return Err(DataFusionError::Plan(format!( + "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes", + payload_size_bytes, max_payload_bytes + ))); + } + + Ok(()) +} + +fn validate_supported_payload_expr(expr: &Arc) -> DataFusionResult<()> { + expr.apply(|node| { + if node.as_any().is::() { + return Err(DataFusionError::Plan( + "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion" + .to_string(), + )); + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(()) +} + +fn validate_decoded_payload_expr( + expr: &Arc, + input_schema: &Schema, +) -> DataFusionResult<()> { + expr.apply(|node| { + if let Some(column) = node.as_any().downcast_ref::() { + let Some(field) = input_schema.fields().get(column.index()) else { + return Err(DataFusionError::Plan(format!( + "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}", + column.name(), + column.index(), + input_schema.fields().len() + ))); + }; + + if field.name() != column.name() { + return Err(DataFusionError::Plan(format!( + "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'", + column.name(), + column.index(), + field.name() + ))); + } + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(()) +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct DynFilterUpdate { + 
pub protocol_version: u32, + pub query_id: String, + pub filter_id: String, + pub epoch: u64, + pub is_complete: bool, + pub payload: DynFilterPayload, +} + +impl DynFilterUpdate { + pub fn new( + query_id: String, + filter_id: String, + epoch: u64, + is_complete: bool, + payload: DynFilterPayload, + ) -> Self { + Self { + protocol_version: DYN_FILTER_PROTOCOL_VERSION, + query_id, + filter_id, + epoch, + is_complete, + payload, + } + } +} + /// The query request to be handled by the RegionServer (Datanode). #[derive(Clone, Debug)] pub struct QueryRequest { @@ -28,3 +180,102 @@ pub struct QueryRequest { /// The form of the query: a logical plan. pub plan: LogicalPlan, } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::physical_expr::expressions::Column; + + use super::*; + + #[test] + fn dyn_filter_update_sets_protocol_version() { + let update = DynFilterUpdate::new( + "query-1".to_string(), + "filter-1".to_string(), + 3, + false, + DynFilterPayload::Datafusion(vec![1, 2, 3]), + ); + + assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION); + assert!(!update.is_complete); + assert!( + matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3]) + ); + } + + #[test] + fn dyn_filter_update_json_round_trip_preserves_payload_shape() { + let update = DynFilterUpdate::new( + "query-2".to_string(), + "filter-9".to_string(), + 9, + true, + DynFilterPayload::Datafusion(vec![9, 8, 7]), + ); + + let json = serde_json::to_string(&update).unwrap(); + let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded, update); + assert!(decoded.is_complete); + assert!( + matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7]) + ); + } + + #[test] + fn dyn_filter_payload_round_trips_physical_column_expr() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let expr: 
Arc = + Arc::new(Column::new_with_schema("host", &schema).unwrap()); + + let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap(); + let decoded = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap(); + + let original = expr.as_any().downcast_ref::().unwrap(); + let decoded = decoded.as_any().downcast_ref::().unwrap(); + + assert_eq!(decoded.name(), original.name()); + assert_eq!(decoded.index(), original.index()); + } + + #[test] + fn dyn_filter_payload_decode_rejects_invalid_bytes() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]); + + let err = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap_err(); + + assert!(matches!(err, DataFusionError::Internal(_))); + } + + #[test] + fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let mismatched_expr: Arc = Arc::new(Column::new("service", 0)); + + let payload = DynFilterPayload::from_datafusion_expr(&mismatched_expr, 1024).unwrap(); + let err = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap_err(); + + assert!(matches!(err, DataFusionError::Plan(_))); + } + + #[test] + fn dyn_filter_payload_rejects_oversized_payload() { + let expr: Arc = Arc::new(Column::new("host", 0)); + + let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err(); + + assert!(matches!(err, DataFusionError::Plan(_))); + } +} diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs index ec10691beab5..9bf028fef7ad 100644 --- a/src/datanode/src/region_server.rs +++ b/src/datanode/src/region_server.rs @@ -25,7 +25,8 @@ use api::region::RegionResponse; use api::v1::meta::TopicStat; use api::v1::region::sync_request::ManifestInfo; use api::v1::region::{ - ListMetadataRequest, RegionResponse as RegionResponseV1, 
SyncRequest, region_request, + ListMetadataRequest, RegionResponse as RegionResponseV1, RemoteDynFilterRequest, SyncRequest, + region_request, }; use api::v1::{ResponseHeader, Status}; use arrow_flight::{FlightData, Ticket}; @@ -84,8 +85,9 @@ use crate::error::{ ConcurrentQueryLimiterTimeoutSnafu, DataFusionSnafu, DecodeLogicalPlanSnafu, ExecuteLogicalPlanSnafu, FindLogicalRegionsSnafu, GetRegionMetadataSnafu, HandleBatchDdlRequestSnafu, HandleBatchOpenRequestSnafu, HandleRegionRequestSnafu, - NewPlanDecoderSnafu, RegionEngineNotFoundSnafu, RegionNotFoundSnafu, RegionNotReadySnafu, - Result, SerializeJsonSnafu, StopRegionEngineSnafu, UnexpectedSnafu, UnsupportedOutputSnafu, + NewPlanDecoderSnafu, NotYetImplementedSnafu, RegionEngineNotFoundSnafu, RegionNotFoundSnafu, + RegionNotReadySnafu, Result, SerializeJsonSnafu, StopRegionEngineSnafu, UnexpectedSnafu, + UnsupportedOutputSnafu, }; use crate::event_listener::RegionServerEventListenerRef; use crate::region_server::catalog::{NameAwareCatalogList, NameAwareDataSourceInjectorBuilder}; @@ -696,6 +698,70 @@ impl RegionServer { Ok(response) } + async fn handle_remote_dyn_filter_request( + &self, + request: &RemoteDynFilterRequest, + ) -> Result { + if request.query_id.is_empty() { + return error::MissingRequiredFieldSnafu { name: "query_id" }.fail(); + } + + match request + .action + .as_ref() + .context(error::MissingRequiredFieldSnafu { name: "action" })? 
+ { + api::v1::region::remote_dyn_filter_request::Action::Update(update) => { + self.handle_remote_dyn_filter_update(&request.query_id, update) + .await + } + api::v1::region::remote_dyn_filter_request::Action::Unregister(unregister) => { + self.handle_remote_dyn_filter_unregister(&request.query_id, unregister) + .await + } + } + } + + async fn handle_remote_dyn_filter_update( + &self, + query_id: &str, + request: &api::v1::region::RemoteDynFilterUpdate, + ) -> Result { + if request.filter_id.is_empty() { + return error::MissingRequiredFieldSnafu { name: "filter_id" }.fail(); + } + + if request.payload.is_empty() { + return error::MissingRequiredFieldSnafu { name: "payload" }.fail(); + } + + NotYetImplementedSnafu { + what: format!( + "remote dyn filter update unary RPC placeholder for query_id {query_id}, filter_id {}", + request.filter_id + ), + } + .fail() + } + + async fn handle_remote_dyn_filter_unregister( + &self, + query_id: &str, + request: &api::v1::region::RemoteDynFilterUnregister, + ) -> Result { + if request.filter_id.is_empty() { + return error::MissingRequiredFieldSnafu { name: "filter_id" }.fail(); + } + + NotYetImplementedSnafu { + what: format!( + "remote dyn filter unregister unary RPC placeholder for query_id {query_id}, filter_id {}", + request.filter_id + ), + } + .fail() + } + /// Sync region manifest and registers new opened logical regions. 
pub async fn sync_region( &self, @@ -767,6 +833,10 @@ impl RegionServerHandler for RegionServer { self.handle_list_metadata_request(list_metadata_request) .await } + region_request::Body::RemoteDynFilter(remote_dyn_filter_request) => { + self.handle_remote_dyn_filter_request(remote_dyn_filter_request) + .await + } _ => self.handle_requests_in_serial(request).await, } .map_err(BoxedError::new) @@ -1670,6 +1740,10 @@ mod tests { use std::assert_matches; use api::v1::SemanticType; + use api::v1::region::{ + RemoteDynFilterRequest, RemoteDynFilterUnregister, RemoteDynFilterUpdate, + remote_dyn_filter_request, + }; use common_error::ext::ErrorExt; use datatypes::prelude::ConcreteDataType; use mito2::test_util::CreateRequestBuilder; @@ -2304,4 +2378,135 @@ mod tests { .await .unwrap_err(); } + + #[tokio::test] + async fn test_handle_remote_dyn_filter_request_requires_query_id() { + let mock_region_server = mock_region_server(); + + let err = mock_region_server + .handle_remote_dyn_filter_request(&RemoteDynFilterRequest { + query_id: String::new(), + action: Some(remote_dyn_filter_request::Action::Unregister( + RemoteDynFilterUnregister { + filter_id: "filter-1".to_string(), + }, + )), + }) + .await + .unwrap_err(); + + assert_matches!( + err, + crate::error::Error::MissingRequiredField { ref name, .. } if name == "query_id" + ); + } + + #[tokio::test] + async fn test_handle_remote_dyn_filter_request_requires_action() { + let mock_region_server = mock_region_server(); + + let err = mock_region_server + .handle_remote_dyn_filter_request(&RemoteDynFilterRequest { + query_id: "query-1".to_string(), + action: None, + }) + .await + .unwrap_err(); + + assert_matches!( + err, + crate::error::Error::MissingRequiredField { ref name, .. 
} if name == "action" + ); + } + + #[tokio::test] + async fn test_handle_remote_dyn_filter_update_requires_filter_id() { + let mock_region_server = mock_region_server(); + + let err = mock_region_server + .handle_remote_dyn_filter_request(&RemoteDynFilterRequest { + query_id: "query-1".to_string(), + action: Some(remote_dyn_filter_request::Action::Update( + RemoteDynFilterUpdate { + filter_id: String::new(), + payload: vec![1], + generation: 1, + is_complete: false, + }, + )), + }) + .await + .unwrap_err(); + + assert_matches!( + err, + crate::error::Error::MissingRequiredField { ref name, .. } if name == "filter_id" + ); + } + + #[tokio::test] + async fn test_handle_remote_dyn_filter_update_requires_payload() { + let mock_region_server = mock_region_server(); + + let err = mock_region_server + .handle_remote_dyn_filter_request(&RemoteDynFilterRequest { + query_id: "query-1".to_string(), + action: Some(remote_dyn_filter_request::Action::Update( + RemoteDynFilterUpdate { + filter_id: "filter-1".to_string(), + payload: Vec::new(), + generation: 1, + is_complete: false, + }, + )), + }) + .await + .unwrap_err(); + + assert_matches!( + err, + crate::error::Error::MissingRequiredField { ref name, .. } if name == "payload" + ); + } + + #[tokio::test] + async fn test_handle_remote_dyn_filter_update_placeholder() { + let mock_region_server = mock_region_server(); + + let err = mock_region_server + .handle_remote_dyn_filter_request(&RemoteDynFilterRequest { + query_id: "query-1".to_string(), + action: Some(remote_dyn_filter_request::Action::Update( + RemoteDynFilterUpdate { + filter_id: "filter-1".to_string(), + payload: vec![1], + generation: 1, + is_complete: false, + }, + )), + }) + .await + .unwrap_err(); + + assert_matches!(err, crate::error::Error::NotYetImplemented { .. 
}); + } + + #[tokio::test] + async fn test_handle_remote_dyn_filter_unregister_placeholder() { + let mock_region_server = mock_region_server(); + + let err = mock_region_server + .handle_remote_dyn_filter_request(&RemoteDynFilterRequest { + query_id: "query-1".to_string(), + action: Some(remote_dyn_filter_request::Action::Unregister( + RemoteDynFilterUnregister { + filter_id: "filter-1".to_string(), + }, + )), + }) + .await + .unwrap_err(); + + assert_matches!(err, crate::error::Error::NotYetImplemented { .. }); + } } diff --git a/src/servers/src/grpc/context_auth.rs b/src/servers/src/grpc/context_auth.rs index 39c4fc5c88df..0cf71cc6a14b 100644 --- a/src/servers/src/grpc/context_auth.rs +++ b/src/servers/src/grpc/context_auth.rs @@ -20,7 +20,10 @@ use auth::{Identity, Password, UserInfoRef, UserProviderRef}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; -use session::context::{Channel, QueryContextBuilder, QueryContextRef}; +use session::context::{ + Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY, + generate_remote_query_id, +}; use snafu::{OptionExt, ResultExt}; use tonic::Status; use tonic::metadata::MetadataMap; @@ -50,6 +53,10 @@ pub fn create_query_context_from_grpc_metadata( .current_catalog(catalog) .current_schema(schema) .channel(Channel::Grpc) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build(), )) } diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index c1f146db6da2..3763725bdcec 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -33,8 +33,11 @@ use common_telemetry::tracing_context::{FutureExt, TracingContext}; use common_telemetry::{debug, error, tracing, warn}; use common_time::timezone::parse_timezone; use futures_util::StreamExt; -use 
session::context::{Channel, QueryContextBuilder, QueryContextRef}; -use session::hints::READ_PREFERENCE_HINT; +use session::context::{ + Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY, + generate_remote_query_id, +}; +use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key}; use snafu::{OptionExt, ResultExt}; use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; @@ -214,7 +217,11 @@ pub(crate) fn create_query_context( .current_catalog(catalog) .current_schema(schema) .timezone(timezone) - .channel(channel); + .channel(channel) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ); if let Some(x) = extensions .iter() @@ -231,6 +238,13 @@ pub(crate) fn create_query_context( } for (key, value) in extensions { + if is_reserved_extension_key(&key) { + debug!( + key = key.as_str(), + "Ignoring reserved external query context extension key" + ); + continue; + } ctx_builder = ctx_builder.set_extension(key, value); } Ok(ctx_builder.build().into()) @@ -308,9 +322,35 @@ mod tests { query_context.read_preference(), ReadPreference::Leader )); + let mut extensions = query_context.extensions().into_iter().collect::>(); + extensions.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + assert_eq!( + extensions[0], + ("auto_create_table".to_string(), "true".to_string()) + ); + assert_eq!(extensions[1].0, REMOTE_QUERY_ID_EXTENSION_KEY.to_string()); + assert_eq!( + query_context.remote_query_id(), + Some(extensions[1].1.as_str()) + ); + } + + #[test] + fn test_create_query_context_ignores_remote_query_id_extension() { + let query_context = create_query_context( + Channel::Grpc, + None, + vec![( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + "spoofed-query-id".to_string(), + )], + ) + .unwrap(); + + assert_ne!(query_context.remote_query_id(), Some("spoofed-query-id")); assert_eq!( - query_context.extensions().into_iter().collect::>(), - vec![("auto_create_table".to_string(), "true".to_string())] + 
query_context.extension(REMOTE_QUERY_ID_EXTENSION_KEY), + query_context.remote_query_id() ); } } diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index d2bfd9eba2f4..07ac146899bc 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -28,7 +28,9 @@ use common_telemetry::warn; use common_time::Timezone; use common_time::timezone::parse_timezone; use headers::Header; -use session::context::QueryContextBuilder; +use session::context::{ + QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY, generate_remote_query_id, +}; use snafu::{OptionExt, ResultExt, ensure}; use crate::error::{ @@ -64,7 +66,11 @@ pub async fn inner_auth( let query_ctx_builder = QueryContextBuilder::default() .current_catalog(catalog.clone()) .current_schema(schema.clone()) - .timezone(timezone); + .timezone(timezone) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ); let query_ctx = query_ctx_builder.build(); let need_auth = need_auth(&req); @@ -388,6 +394,19 @@ mod tests { assert!(auth_scheme.is_err()); } + #[test] + fn test_inner_auth_assigns_remote_query_id() { + let req = + mock_http_request(None, Some("http://127.0.0.1/v1/sql?db=greptime-public")).unwrap(); + let req = futures::executor::block_on(inner_auth::<()>(None, req)).unwrap(); + let query_ctx = req + .extensions() + .get::() + .unwrap(); + + assert!(query_ctx.remote_query_id().is_some()); + } + #[test] fn test_auth_header() { // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" diff --git a/src/servers/src/http/hints.rs b/src/servers/src/http/hints.rs index 7f98461cf634..406f22673881 100644 --- a/src/servers/src/http/hints.rs +++ b/src/servers/src/http/hints.rs @@ -16,16 +16,65 @@ use axum::body::Body; use axum::http::Request; use axum::middleware::Next; use axum::response::Response; +use common_telemetry::debug; use session::context::QueryContext; +use session::hints::is_reserved_extension_key; use 
crate::hint_headers; pub async fn extract_hints(mut request: Request, next: Next) -> Response { let hints = hint_headers::extract_hints(request.headers()); if let Some(query_ctx) = request.extensions_mut().get_mut::() { - for (key, value) in hints { - query_ctx.set_extension(key, value); - } + apply_hints(query_ctx, hints); } next.run(request).await } + +fn apply_hints(query_ctx: &mut QueryContext, hints: Vec<(String, String)>) { + for (key, value) in hints { + if is_reserved_extension_key(&key) { + debug!( + key = key.as_str(), + "Ignoring reserved external query context extension key" + ); + continue; + } + query_ctx.set_extension(key, value); + } +} + +#[cfg(test)] +mod tests { + use session::context::{QueryContextBuilder, generate_remote_query_id}; + use session::hints::REMOTE_QUERY_ID_EXTENSION_KEY; + + use super::apply_hints; + + #[test] + fn test_apply_hints_ignores_reserved_extension_keys() { + let original_query_id = generate_remote_query_id(); + let mut query_ctx = QueryContextBuilder::default() + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + original_query_id.clone(), + ) + .build(); + + apply_hints( + &mut query_ctx, + vec![ + ( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + "spoofed".to_string(), + ), + ("ttl".to_string(), "7d".to_string()), + ], + ); + + assert_eq!( + query_ctx.remote_query_id(), + Some(original_query_id.as_str()) + ); + assert_eq!(query_ctx.extension("ttl"), Some("7d")); + } +} diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index 5b8b60f5ab57..c7be8c2f7e85 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -27,3 +27,4 @@ derive_builder.workspace = true derive_more.workspace = true snafu.workspace = true sql.workspace = true +uuid.workspace = true diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 5f16ea8b5a6e..864db9f1f12a 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -32,7 +32,9 @@ use datafusion_common::config::ConfigOptions; use 
derive_builder::Builder; use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; +pub use crate::hints::REMOTE_QUERY_ID_EXTENSION_KEY; use crate::protocol_ctx::ProtocolCtx; +use crate::query_id::QueryId; use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle}; use crate::{MutableInner, ReadPreference}; @@ -41,6 +43,14 @@ pub type ConnInfoRef = Arc; const CURSOR_COUNT_WARNING_LIMIT: usize = 10; +pub fn generate_remote_query_id() -> String { + generate_remote_query_id_value().to_string() +} + +pub fn generate_remote_query_id_value() -> QueryId { + QueryId::new() +} + #[derive(Debug, Builder, Clone)] #[builder(pattern = "owned")] #[builder(build_fn(skip))] @@ -152,7 +162,12 @@ impl From<&RegionRequestHeader> for QueryContext { if let Some(ctx) = &value.query_context { ctx.clone().into() } else { - QueryContextBuilder::default().build() + QueryContextBuilder::default() + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) + .build() } } } @@ -219,7 +234,14 @@ impl From<&QueryContext> for api::v1::QueryContext { impl QueryContext { pub fn arc() -> QueryContextRef { - Arc::new(QueryContextBuilder::default().build()) + Arc::new( + QueryContextBuilder::default() + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) + .build(), + ) } /// Create a new datafusion's ConfigOptions instance based on the current QueryContext. 
@@ -233,6 +255,10 @@ impl QueryContext { QueryContextBuilder::default() .current_catalog(catalog.to_string()) .current_schema(schema.to_string()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() } @@ -241,6 +267,10 @@ .current_catalog(catalog.to_string()) .current_schema(schema.to_string()) .channel(channel) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() } @@ -259,6 +289,10 @@ QueryContextBuilder::default() .current_catalog(catalog) .current_schema(schema.clone()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() } @@ -320,6 +354,15 @@ self.extensions.get(key.as_ref()).map(|v| v.as_str()) } + pub fn remote_query_id(&self) -> Option<&str> { + self.extension(REMOTE_QUERY_ID_EXTENSION_KEY) + } + + pub fn remote_query_id_value(&self) -> Option<QueryId> { + self.remote_query_id() + .and_then(|query_id| query_id.parse().ok()) + } + pub fn extensions(&self) -> HashMap<String, String> { self.extensions.clone() } @@ -483,6 +526,10 @@ impl QueryContextBuilder { pub fn build(self) -> QueryContext { let channel = self.channel.unwrap_or_default(); + let mut extensions = self.extensions.unwrap_or_default(); + extensions + .entry(REMOTE_QUERY_ID_EXTENSION_KEY.to_string()) + .or_insert_with(generate_remote_query_id); QueryContext { current_catalog: self .current_catalog @@ -494,7 +541,7 @@ sql_dialect: self .sql_dialect .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})), - extensions: self.extensions.unwrap_or_default(), + extensions, configuration_parameter: self .configuration_parameter .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())), @@ -707,6 +754,9 @@ mod test { assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string()); assert_eq!(100, session.process_id()); + + let query_ctx = 
session.new_query_context(); + assert!(query_ctx.remote_query_id().is_some()); } #[test] @@ -739,8 +789,40 @@ mod test { assert_eq!(roundtrip_api.current_catalog, api_ctx.current_catalog); assert_eq!(roundtrip_api.current_schema, api_ctx.current_schema); assert_eq!(roundtrip_api.timezone, api_ctx.timezone); - assert_eq!(roundtrip_api.extensions, api_ctx.extensions); + assert_eq!( + roundtrip_api.extensions.get("flow.return_region_seq"), + Some(&"true".to_string()) + ); + assert!( + roundtrip_api + .extensions + .contains_key(REMOTE_QUERY_ID_EXTENSION_KEY) + ); assert_eq!(roundtrip_api.channel, api_ctx.channel); assert_eq!(roundtrip_api.snapshot_seqs, api_ctx.snapshot_seqs); } + + #[test] + fn test_query_context_remote_query_id_round_trip() { + let query_id = "0195f4fd-c503-7c54-8b8f-7dfb8f6f9c4a"; + let ctx = QueryContextBuilder::default() + .current_catalog(DEFAULT_CATALOG_NAME.to_string()) + .current_schema("public".to_string()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + query_id.to_string(), + ) + .build(); + + assert_eq!(ctx.remote_query_id(), Some(query_id)); + assert_eq!(ctx.remote_query_id_value().unwrap().to_string(), query_id); + + let proto: api::v1::QueryContext = (&ctx).into(); + let restored = QueryContext::from(proto); + assert_eq!(restored.remote_query_id(), Some(query_id)); + assert_eq!( + restored.remote_query_id_value().unwrap().to_string(), + query_id + ); + } } diff --git a/src/session/src/hints.rs b/src/session/src/hints.rs index e2a5b5fff874..db51bc627971 100644 --- a/src/session/src/hints.rs +++ b/src/session/src/hints.rs @@ -16,8 +16,10 @@ pub const HINTS_KEY: &str = "x-greptime-hints"; /// Deprecated, use `HINTS_KEY` instead. Notes if "x-greptime-hints" is set, keys with this prefix will be ignored. 
pub const HINTS_KEY_PREFIX: &str = "x-greptime-hint-"; +pub const REMOTE_QUERY_ID_EXTENSION_KEY: &str = "remote_query_id"; pub const READ_PREFERENCE_HINT: &str = "read_preference"; +pub const RESERVED_EXTENSION_KEYS: [&str; 1] = [REMOTE_QUERY_ID_EXTENSION_KEY]; /// Deprecated, use `HINTS_KEY` instead. pub const HINT_KEYS: [&str; 7] = [ @@ -29,3 +31,18 @@ pub const HINT_KEYS: [&str; 7] = [ "x-greptime-hint-skip_wal", "x-greptime-hint-read_preference", ]; + +pub fn is_reserved_extension_key(key: &str) -> bool { + RESERVED_EXTENSION_KEYS.contains(&key) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_reserved_extension_key() { + assert!(is_reserved_extension_key(REMOTE_QUERY_ID_EXTENSION_KEY)); + assert!(!is_reserved_extension_key(READ_PREFERENCE_HINT)); + } +} diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 8d2a3e214180..cba78a060ee8 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -15,6 +15,7 @@ pub mod context; pub mod hints; pub mod protocol_ctx; +pub mod query_id; pub mod session_config; pub mod table_name; @@ -30,7 +31,10 @@ use common_recordbatch::cursor::RecordBatchStreamCursor; pub use common_session::ReadPreference; use common_time::Timezone; use common_time::timezone::get_timezone; -use context::{ConfigurationVariables, QueryContextBuilder}; +use context::{ + ConfigurationVariables, QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY, + generate_remote_query_id, +}; use derive_more::Debug; use crate::context::{Channel, ConnInfo, QueryContextRef}; @@ -106,6 +110,10 @@ impl Session { .channel(self.conn_info.channel) .process_id(self.process_id) .conn_info(self.conn_info.clone()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() .into() } diff --git a/src/session/src/query_id.rs b/src/session/src/query_id.rs new file mode 100644 index 000000000000..220e52577ea6 --- /dev/null +++ b/src/session/src/query_id.rs @@ -0,0 +1,76 @@ +// Copyright 2023 
Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::{Display, Formatter}; +use std::str::FromStr; + +use uuid::Uuid; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct QueryId(Uuid); + +impl QueryId { + pub fn new() -> Self { + Self(Uuid::now_v7()) + } + + pub fn as_uuid(&self) -> &Uuid { + &self.0 + } +} + +impl Default for QueryId { + fn default() -> Self { + Self::new() + } +} + +impl Display for QueryId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl FromStr for QueryId { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(Self(Uuid::parse_str(s)?)) + } +} + +impl From<Uuid> for QueryId { + fn from(value: Uuid) -> Self { + Self(value) + } +} + +impl From<QueryId> for Uuid { + fn from(value: QueryId) -> Self { + value.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn query_id_round_trips_through_string() { + let query_id = QueryId::new(); + let encoded = query_id.to_string(); + + assert_eq!(encoded.parse::<QueryId>().unwrap(), query_id); + } +} diff --git a/src/store-api/src/region_request.rs b/src/store-api/src/region_request.rs index 99d3a87dd320..33a66d0d0dcb 100644 --- a/src/store-api/src/region_request.rs +++ b/src/store-api/src/region_request.rs @@ -184,6 +184,10 @@ impl RegionRequest { reason: "ListMetadata request should be handled separately by RegionServer", } .fail(), + region_request::Body::RemoteDynFilter(_) => UnexpectedSnafu { + reason: 
"RemoteDynFilter request should be handled separately by RegionServer", + } + .fail(), region_request::Body::ApplyStagingManifest(apply) => { make_region_apply_staging_manifest(apply) }