diff --git a/src/client/src/database.rs b/src/client/src/database.rs index e12c2ec0fc63..e1985318dfdd 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -14,7 +14,9 @@ use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; use api::v1::ddl_request::Expr as DdlExpr; @@ -25,6 +27,7 @@ use api::v1::{ AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest, InsertRequests, QueryRequest, RequestHeader, RowInsertRequests, }; +use arc_swap::ArcSwapOption; use arrow_flight::{FlightData, Ticket}; use async_stream::stream; use base64::Engine; @@ -33,17 +36,18 @@ use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::ext::BoxedError; use common_grpc::flight::do_put::DoPutResponse; -use common_grpc::flight::{FlightDecoder, FlightMessage}; +use common_grpc::flight::{FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightMessage}; use common_query::Output; +use common_recordbatch::adapter::RecordBatchMetrics; use common_recordbatch::error::ExternalSnafu; -use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper}; +use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper}; use common_telemetry::tracing::Span; use common_telemetry::tracing_context::W3cTrace; use common_telemetry::{error, warn}; use futures::future; use futures_util::{Stream, StreamExt, TryStreamExt}; use prost::Message; -use snafu::{OptionExt, ResultExt, ensure}; +use snafu::{OptionExt, ResultExt}; use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue}; use tonic::transport::Channel; @@ -57,6 +61,315 @@ type FlightDataStream = Pin + Send>>; type DoPutResponseStream = Pin>>>; +/// Terminal metrics associated with a query output. +/// +/// For streaming outputs, metrics are only final after the stream is fully +/// drained and [`Self::is_ready`] returns `true`. Region watermark helpers keep +/// the RFC distinction between proved regions, unproved participating regions, +/// and non-participating regions. +#[derive(Debug, Clone, Default)] +pub struct OutputMetrics { + inner: Arc, +} + +#[derive(Debug, Default)] +struct OutputMetricsInner { + metrics: RwLock>, + ready: AtomicBool, +} + +impl OutputMetrics { + fn new() -> Self { + Self::default() + } + + /// Replaces the current terminal metrics snapshot. + pub fn update(&self, metrics: Option) { + *self.inner.metrics.write().unwrap() = metrics; + } + + /// Marks the terminal metrics as final for this output. + pub fn mark_ready(&self) { + let _ = self + .inner + .ready + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire); + } + + /// Returns whether terminal metrics are final. + /// + /// Streaming outputs become ready only after the stream reaches EOF. + pub fn is_ready(&self) -> bool { + self.inner.ready.load(Ordering::Acquire) + } + + /// Returns the latest terminal metrics snapshot, if any. + pub fn get(&self) -> Option { + self.inner.metrics.read().unwrap().clone() + } + + /// Returns proved per-region watermarks. + /// + /// Entries whose watermark is `None` are intentionally omitted because they + /// represent participating regions whose terminal sequence bound was not + /// provable. + pub fn region_watermark_map(&self) -> Option> { + Some( + self.get()? + .region_watermarks + .into_iter() + .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq))) + .collect::>(), + ) + } + + /// Returns all regions that participated in terminal metric collection, + /// including entries whose watermark is `None`. + pub fn participating_regions(&self) -> Option> { + Some( + self.get()? + .region_watermarks + .into_iter() + .map(|entry| entry.region_id) + .collect::>(), + ) + } +} + +/// Query output together with a handle for its terminal metrics. +/// +/// The contained [`OutputMetrics`] lets callers read stream terminal metrics +/// after consuming `output`. For non-stream outputs, metrics are ready +/// immediately. +#[derive(Debug)] +pub struct OutputWithMetrics { + pub output: Output, + pub metrics: OutputMetrics, +} + +impl OutputWithMetrics { + /// Wraps an output with a terminal metrics handle. + /// + /// Stream outputs update the handle as the stream is consumed. Non-stream + /// outputs are marked ready immediately. + pub fn from_output(output: Output) -> Self { + let terminal_metrics = OutputMetrics::new(); + let output = attach_terminal_metrics(output, &terminal_metrics); + Self { + output, + metrics: terminal_metrics, + } + } + + /// Returns proved per-region watermarks from the terminal metrics. + pub fn region_watermark_map(&self) -> Option> { + self.metrics.region_watermark_map() + } + + /// Returns all regions participating in terminal metric collection. + pub fn participating_regions(&self) -> Option> { + self.metrics.participating_regions() + } + + /// Drops the terminal metrics handle and returns the original output. + pub fn into_output(self) -> Output { + self.output + } +} + +fn parse_terminal_metrics(metrics_json: &str) -> Result { + serde_json::from_str(metrics_json).map_err(|e| { + IllegalFlightMessagesSnafu { + reason: format!("Invalid terminal metrics message: {e}"), + } + .build() + }) +} + +struct StreamWithMetrics { + stream: common_recordbatch::SendableRecordBatchStream, + metrics: OutputMetrics, +} + +impl StreamWithMetrics { + fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self { + Self { stream, metrics } + } + + fn sync_terminal_metrics(&self) { + self.metrics.update(self.stream.metrics()); + } +} + +impl RecordBatchStream for StreamWithMetrics { + fn name(&self) -> &str { + self.stream.name() + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.stream.schema() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.stream.output_ordering() + } + + fn metrics(&self) -> Option { + self.sync_terminal_metrics(); + self.metrics.get() + } +} + +impl Stream for StreamWithMetrics { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let polled = Pin::new(&mut self.stream).poll_next(cx); + if let Poll::Ready(None) = &polled { + self.sync_terminal_metrics(); + self.metrics.mark_ready(); + } + polled + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output { + let Output { data, meta } = output; + let data = match data { + common_query::OutputData::Stream(stream) => { + terminal_metrics.update(stream.metrics()); + common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new( + stream, + terminal_metrics.clone(), + ))) + } + other => { + terminal_metrics.mark_ready(); + other + } + }; + Output::new(data, meta) +} + +async fn output_from_flight_message_stream( + mut flight_message_stream: S, +) -> Result +where + S: Stream> + Send + Unpin + 'static, +{ + let Some(first_flight_message) = flight_message_stream.next().await else { + return IllegalFlightMessagesSnafu { + reason: "Expect the response not to be empty", + } + .fail(); + }; + + let first_flight_message = first_flight_message?; + + match first_flight_message { + FlightMessage::AffectedRows { rows, metrics } => { + let terminal_metrics = OutputMetrics::new(); + if let Some(metrics) = metrics { + terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?)); + } + let next_message = flight_message_stream.next().await.transpose()?; + match next_message { + None => terminal_metrics.mark_ready(), + Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => { + terminal_metrics.update(Some(parse_terminal_metrics(&s)?)); + terminal_metrics.mark_ready(); + } + Some(FlightMessage::Metrics(_)) => { + return IllegalFlightMessagesSnafu { + reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message".to_string(), + } + .fail(); + } + Some(other) => { + return IllegalFlightMessagesSnafu { + reason: format!( + "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}" + ), + } + .fail(); + } + } + Ok(OutputWithMetrics { + output: Output::new_with_affected_rows(rows), + metrics: terminal_metrics, + }) + } + FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu { + reason: "The first flight message cannot be a RecordBatch or Metrics message", + } + .fail(), + FlightMessage::Schema(schema) => { + let metrics = Arc::new(ArcSwapOption::from(None)); + let metrics_ref = metrics.clone(); + let schema = Arc::new( + datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?, + ); + let schema_cloned = schema.clone(); + let stream = Box::pin(stream!({ + while let Some(flight_message_item) = flight_message_stream.next().await { + let flight_message = match flight_message_item { + Ok(message) => message, + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + break; + } + }; + + match flight_message { + FlightMessage::RecordBatch(arrow_batch) => { + yield Ok(RecordBatch::from_df_record_batch( + schema_cloned.clone(), + arrow_batch, + )) + } + FlightMessage::Metrics(s) => { + match parse_terminal_metrics(&s) { + Ok(m) => { + metrics_ref.swap(Some(Arc::new(m))); + } + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + } + }; + } + FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => { + yield IllegalFlightMessagesSnafu { + reason: format!( + "A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", + flight_message + ) + } + .fail() + .map_err(BoxedError::new) + .context(ExternalSnafu); + break; + } + } + } + })); + let record_batch_stream = RecordBatchStreamWrapper { + schema, + stream, + output_ordering: None, + metrics, + span: Span::current(), + }; + Ok(OutputWithMetrics::from_output(Output::new_with_stream( + Box::pin(record_batch_stream), + ))) + } + } +} + #[derive(Clone, Debug, Default)] pub struct Database { // The "catalog" and "schema" to be used in processing the requests at the server side. @@ -238,6 +551,22 @@ impl Database { Ok(()) } + fn put_flow_extensions( + metadata: &mut MetadataMap, + flow_extensions: &[(&str, &str)], + ) -> Result<()> { + if flow_extensions.is_empty() { + return Ok(()); + } + + let value = serde_json::to_string(&flow_extensions.to_vec()) + .expect("flow extension pairs should serialize"); + let key = AsciiMetadataKey::from_static(FLOW_EXTENSIONS_METADATA_KEY); + let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?; + metadata.insert(key, value); + Ok(()) + } + /// Make a request to the database. pub async fn handle(&self, request: Request) -> Result { let mut client = make_database_client(&self.client)?; @@ -333,15 +662,58 @@ impl Database { let request = Request::Query(QueryRequest { query: Some(Query::Sql(sql.as_ref().to_string())), }); - self.do_get(request, hints).await + self.do_get(request, hints, &[]) + .await + .map(OutputWithMetrics::into_output) + } + + /// Executes a SQL query and returns the output with terminal metrics. + /// + /// For stream outputs, callers must consume the stream before reading final + /// terminal metrics from [`OutputWithMetrics::metrics`]. + pub async fn sql_with_terminal_metrics( + &self, + sql: S, + hints: &[(&str, &str)], + ) -> Result + where + S: AsRef, + { + self.query_with_terminal_metrics_and_flow_extensions( + QueryRequest { + query: Some(Query::Sql(sql.as_ref().to_string())), + }, + hints, + &[], + ) + .await } /// Executes a logical plan directly without SQL parsing. pub async fn logical_plan(&self, logical_plan: Vec) -> Result { - let request = Request::Query(QueryRequest { - query: Some(Query::LogicalPlan(logical_plan)), - }); - self.do_get(request, &[]).await + self.query_with_terminal_metrics_and_flow_extensions( + QueryRequest { + query: Some(Query::LogicalPlan(logical_plan)), + }, + &[], + &[], + ) + .await + .map(OutputWithMetrics::into_output) + } + + /// Executes a query and carries flow extensions through Flight metadata. + /// + /// This is the lower-level terminal-metrics API for Flow callers that need + /// to pass JSON-bearing flow extensions without going through hint metadata. + pub async fn query_with_terminal_metrics_and_flow_extensions( + &self, + request: QueryRequest, + hints: &[(&str, &str)], + flow_extensions: &[(&str, &str)], + ) -> Result { + self.do_get(Request::Query(request), hints, flow_extensions) + .await } /// Creates a new table using the provided table expression. @@ -349,7 +721,9 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::CreateTable(expr)), }); - self.do_get(request, &[]).await + self.do_get(request, &[], &[]) + .await + .map(OutputWithMetrics::into_output) } /// Alters an existing table using the provided alter expression. @@ -357,17 +731,26 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::AlterTable(expr)), }); - self.do_get(request, &[]).await + self.do_get(request, &[], &[]) + .await + .map(OutputWithMetrics::into_output) } - async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result { + async fn do_get( + &self, + request: Request, + hints: &[(&str, &str)], + flow_extensions: &[(&str, &str)], + ) -> Result { let request = self.to_rpc_request(request); let request = Ticket { ticket: request.encode_to_vec().into(), }; let mut request = tonic::Request::new(request); - Self::put_hints(request.metadata_mut(), hints)?; + let metadata = request.metadata_mut(); + Self::put_hints(metadata, hints)?; + Self::put_flow_extensions(metadata, flow_extensions)?; let mut client = self.client.make_flight_client(false, false)?; @@ -389,7 +772,7 @@ impl Database { let flight_data_stream = response.into_inner(); let mut decoder = FlightDecoder::default(); - let mut flight_message_stream = flight_data_stream.map(move |flight_data| { + let flight_message_stream = flight_data_stream.map(move |flight_data| { flight_data .map_err(Error::from) .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))? @@ -398,70 +781,7 @@ impl Database { }) }); - let Some(first_flight_message) = flight_message_stream.next().await else { - return IllegalFlightMessagesSnafu { - reason: "Expect the response not to be empty", - } - .fail(); - }; - - let first_flight_message = first_flight_message?; - - match first_flight_message { - FlightMessage::AffectedRows(rows) => { - ensure!( - flight_message_stream.next().await.is_none(), - IllegalFlightMessagesSnafu { - reason: "Expect 'AffectedRows' Flight messages to be the one and the only!" - } - ); - Ok(Output::new_with_affected_rows(rows)) - } - FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => { - IllegalFlightMessagesSnafu { - reason: "The first flight message cannot be a RecordBatch or Metrics message", - } - .fail() - } - FlightMessage::Schema(schema) => { - let schema = Arc::new( - datatypes::schema::Schema::try_from(schema) - .context(error::ConvertSchemaSnafu)?, - ); - let schema_cloned = schema.clone(); - let stream = Box::pin(stream!({ - while let Some(flight_message) = flight_message_stream.next().await { - let flight_message = flight_message - .map_err(BoxedError::new) - .context(ExternalSnafu)?; - match flight_message { - FlightMessage::RecordBatch(arrow_batch) => { - yield Ok(RecordBatch::from_df_record_batch( - schema_cloned.clone(), - arrow_batch, - )) - } - FlightMessage::Metrics(_) => {} - FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => { - yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)} - .fail() - .map_err(BoxedError::new) - .context(ExternalSnafu); - break; - } - } - } - })); - let record_batch_stream = RecordBatchStreamWrapper { - schema, - stream, - output_ordering: None, - metrics: Default::default(), - span: Span::current(), - }; - Ok(Output::new_with_stream(Box::pin(record_batch_stream))) - } - } + output_from_flight_message_stream(flight_message_stream).await } /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`" @@ -512,16 +832,104 @@ struct FlightContext { #[cfg(test)] mod tests { - use std::assert_matches; + use std::sync::Arc; + use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; use api::v1::{AuthHeader, Basic}; use common_error::status_code::StatusCode; + use common_query::OutputData; + use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream}; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures_util::StreamExt; use tonic::{Code, Status}; use super::*; use crate::error::TonicSnafu; + struct MockMetricsStream { + schema: datatypes::schema::SchemaRef, + batch: Option, + metrics: RecordBatchMetrics, + terminal_metrics_only: bool, + } + + impl Stream for MockMetricsStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.batch.take().map(Ok)) + } + } + + impl RecordBatchStream for MockMetricsStream { + fn name(&self) -> &str { + "MockMetricsStream" + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + if self.terminal_metrics_only && self.batch.is_some() { + return None; + } + Some(self.metrics.clone()) + } + } + + fn terminal_metrics_json() -> String { + terminal_metrics_json_with_seq(42) + } + + fn terminal_metrics_json_with_seq(seq: u64) -> String { + serde_json::to_string(&RecordBatchMetrics { + region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry { + region_id: 7, + watermark: Some(seq), + }], + ..Default::default() + }) + .unwrap() + } + + #[test] + fn test_put_flow_extensions_preserves_comma_bearing_values() { + let mut metadata = MetadataMap::new(); + Database::put_flow_extensions( + &mut metadata, + &[ + ("flow.return_region_seq", "true"), + ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#), + ], + ) + .unwrap(); + + let value = metadata + .get(FLOW_EXTENSIONS_METADATA_KEY) + .unwrap() + .to_str() + .unwrap(); + let decoded: Vec<(String, String)> = serde_json::from_str(value).unwrap(); + assert_eq!( + decoded, + vec![ + ("flow.return_region_seq".to_string(), "true".to_string()), + ( + "flow.incremental_after_seqs".to_string(), + r#"{"1":10,"2":20}"#.to_string() + ), + ] + ); + } + #[test] fn test_flight_ctx() { let mut ctx = FlightContext::default(); @@ -536,12 +944,12 @@ mod tests { auth_scheme: Some(basic), }); - assert_matches!( + assert!(matches!( ctx.auth_header, Some(AuthHeader { auth_scheme: Some(AuthScheme::Basic(_)), }) - ) + )); } #[test] @@ -558,4 +966,198 @@ mod tests { assert_eq!(expected.to_string(), actual.to_string()); } + + #[tokio::test] + async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef], + ) + .unwrap(); + let output = Output::new_with_stream(Box::pin(MockMetricsStream { + schema, + batch: Some(batch), + metrics: RecordBatchMetrics { + region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry { + region_id: 7, + watermark: Some(42), + }], + ..Default::default() + }, + terminal_metrics_only: true, + })); + + let result = OutputWithMetrics::from_output(output); + let terminal_metrics = result.metrics.clone(); + assert!(!terminal_metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); + + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while stream.next().await.is_some() {} + + assert!(terminal_metrics.is_ready()); + assert_eq!( + terminal_metrics.participating_regions(), + Some(std::collections::BTreeSet::from([7_u64])) + ); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(std::collections::HashMap::from([(7_u64, 42_u64)])) + ); + } + + #[test] + fn test_parse_terminal_metrics_rejects_invalid_json() { + assert!(parse_terminal_metrics("{not-json}").is_err()); + } + + #[tokio::test] + async fn test_affected_rows_inline_metrics_are_parsed() { + let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok( + FlightMessage::AffectedRows { + rows: 3, + metrics: Some(terminal_metrics_json()), + }, + )] + as Vec>)) + .await + .unwrap(); + + assert!(matches!(output.output.data, OutputData::AffectedRows(3))); + assert!(output.metrics.is_ready()); + assert_eq!( + output.metrics.region_watermark_map(), + Some(std::collections::HashMap::from([(7, 42)])) + ); + } + + #[tokio::test] + async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() { + let metrics_json = terminal_metrics_json(); + let err = output_from_flight_message_stream(futures_util::stream::iter(vec![ + Ok(FlightMessage::AffectedRows { + rows: 3, + metrics: Some(metrics_json.clone()), + }), + Ok(FlightMessage::Metrics(metrics_json)), + ] + as Vec>)) + .await + .unwrap_err(); + + assert!( + err.to_string().contains("already carries Metrics"), + "unexpected error: {err:?}" + ); + } + + #[tokio::test] + async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef], + ) + .unwrap(); + let output = output_from_flight_message_stream(futures_util::stream::iter(vec![ + Ok(FlightMessage::Schema(schema.arrow_schema().clone())), + Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())), + Ok(FlightMessage::Metrics("{not-json}".to_string())), + ] + as Vec>)) + .await + .unwrap(); + let terminal_metrics = output.metrics.clone(); + let OutputData::Stream(mut record_batch_stream) = output.output.data else { + panic!("expected stream output"); + }; + + let batch = record_batch_stream.next().await.unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + + let err = record_batch_stream.next().await.unwrap().unwrap_err(); + assert_eq!("External error", err.to_string()); + assert!( + format!("{err:?}").contains("Invalid terminal metrics message"), + "unexpected error: {err:?}" + ); + assert!(record_batch_stream.next().await.is_none()); + assert!(terminal_metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); + } + + #[tokio::test] + async fn test_record_batch_stream_continues_after_partial_metrics() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let first_batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef], + ) + .unwrap(); + let second_batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([2])) as VectorRef], + ) + .unwrap(); + let output = output_from_flight_message_stream(futures_util::stream::iter(vec![ + Ok(FlightMessage::Schema(schema.arrow_schema().clone())), + Ok(FlightMessage::RecordBatch( + first_batch.into_df_record_batch(), + )), + Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(1))), + Ok(FlightMessage::RecordBatch( + second_batch.into_df_record_batch(), + )), + Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(2))), + ] + as Vec>)) + .await + .unwrap(); + let terminal_metrics = output.metrics.clone(); + let OutputData::Stream(mut record_batch_stream) = output.output.data else { + panic!("expected stream output"); + }; + + let first_batch = record_batch_stream.next().await.unwrap().unwrap(); + assert_eq!(first_batch.num_rows(), 1); + let second_batch = record_batch_stream.next().await.unwrap().unwrap(); + assert_eq!(second_batch.num_rows(), 1); + assert!(record_batch_stream.next().await.is_none()); + + assert!(terminal_metrics.is_ready()); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(std::collections::HashMap::from([(7, 2)])) + ); + } + + #[test] + fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() { + let metrics = OutputMetrics::default(); + metrics.update(Some(RecordBatchMetrics::default())); + + assert_eq!( + metrics.participating_regions(), + Some(std::collections::BTreeSet::new()) + ); + assert_eq!( + metrics.region_watermark_map(), + Some(std::collections::HashMap::new()) + ); + } } diff --git a/src/client/src/lib.rs b/src/client/src/lib.rs index 0c9334b7d4a2..147dffc1456f 100644 --- a/src/client/src/lib.rs +++ b/src/client/src/lib.rs @@ -32,7 +32,7 @@ pub use common_recordbatch::{RecordBatches, SendableRecordBatchStream}; use snafu::OptionExt; pub use self::client::Client; -pub use self::database::Database; +pub use self::database::{Database, OutputMetrics, OutputWithMetrics}; pub use self::error::{Error, Result}; use crate::error::{IllegalDatabaseResponseSnafu, ServerSnafu}; diff --git a/src/common/grpc/src/flight.rs b/src/common/grpc/src/flight.rs index 5fc115a60e91..c7860f5706dd 100644 --- a/src/common/grpc/src/flight.rs +++ b/src/common/grpc/src/flight.rs @@ -37,11 +37,17 @@ use vec1::{Vec1, vec1}; use crate::error; use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result}; +/// Flight metadata key used to carry flow query extensions as JSON pairs. +pub const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions"; + #[derive(Debug, Clone)] pub enum FlightMessage { Schema(SchemaRef), RecordBatch(DfRecordBatch), - AffectedRows(usize), + AffectedRows { + rows: usize, + metrics: Option, + }, Metrics(String), } @@ -116,10 +122,12 @@ impl FlightEncoder { encoded_batch.into(), ) } - FlightMessage::AffectedRows(rows) => { + FlightMessage::AffectedRows { rows, metrics } => { let metadata = FlightMetadata { affected_rows: Some(AffectedRows { value: rows as _ }), - metrics: None, + metrics: metrics.map(|s| Metrics { + metrics: s.into_bytes(), + }), } .encode_to_vec(); vec1![FlightData { @@ -223,7 +231,12 @@ impl FlightDecoder { let metadata = FlightMetadata::decode(flight_data.app_metadata.clone()) .context(DecodeFlightDataSnafu)?; if let Some(AffectedRows { value }) = metadata.affected_rows { - return Ok(Some(FlightMessage::AffectedRows(value as _))); + return Ok(Some(FlightMessage::AffectedRows { + rows: value as _, + metrics: metadata + .metrics + .map(|m| String::from_utf8_lossy(&m.metrics).to_string()), + })); } if let Some(Metrics { metrics }) = metadata.metrics { return Ok(Some(FlightMessage::Metrics( @@ -426,6 +439,47 @@ mod test { Ok(()) } + #[test] + fn test_affected_rows_metrics_encode_decode() -> Result<()> { + let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#; + let mut encoder = FlightEncoder::default(); + let encoded = encoder.encode(FlightMessage::AffectedRows { + rows: 3, + metrics: Some(metrics.to_string()), + }); + + assert_eq!(encoded.len(), 1); + + let mut decoder = FlightDecoder::default(); + let decoded = decoder.try_decode(encoded.first())?.unwrap(); + let FlightMessage::AffectedRows { + rows, + metrics: decoded_metrics, + } = decoded + else { + unreachable!() + }; + assert_eq!(rows, 3); + assert_eq!(decoded_metrics.as_deref(), Some(metrics)); + + let encoded = encoder.encode(FlightMessage::AffectedRows { + rows: 5, + metrics: None, + }); + let decoded = decoder.try_decode(encoded.first())?.unwrap(); + let FlightMessage::AffectedRows { + rows, + metrics: decoded_metrics, + } = decoded + else { + unreachable!() + }; + assert_eq!(rows, 5); + assert!(decoded_metrics.is_none()); + + Ok(()) + } + #[test] fn test_flight_messages_to_recordbatches() { let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)])); @@ -548,4 +602,76 @@ mod test { assert_eq!(actual, expected.trim()); Ok(()) } + + #[test] + fn test_affected_rows_roundtrip_through_flight_codec() { + // Verify the full FlightEncoder → FlightDecoder pipeline handles + // the new FlightMessage::AffectedRows variant with optional inline + // metrics without breaking the wire protocol. + let mut encoder = FlightEncoder::default(); + let mut decoder = FlightDecoder::default(); + + // Without metrics — same wire format as old `AffectedRows(7)`. + let encoded = encoder.encode(FlightMessage::AffectedRows { + rows: 7, + metrics: None, + }); + let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap(); + assert!(matches!( + decoded, + FlightMessage::AffectedRows { + rows: 7, + metrics: None, + } + )); + + // With metrics — new capability, row count preserved. + let json = r#"{"region_watermarks":[{"region_id":1,"watermark":99}]}"#; + let encoded = encoder.encode(FlightMessage::AffectedRows { + rows: 42, + metrics: Some(json.to_string()), + }); + let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap(); + assert!(matches!( + decoded, + FlightMessage::AffectedRows { + rows: 42, + metrics: Some(_), + } + )); + } + + /// Simulates the wire output of the **old** `FlightMessage::AffectedRows(usize)` + /// variant and verifies that the **new** `FlightDecoder` handles it. + #[test] + fn test_old_affected_rows_format_decoded_by_new_code() { + use arrow_flight::FlightData; + use prost::bytes::Bytes as ProstBytes; + + // The old encoder produced FlightData whose app_metadata is + // FlightMetadata { affected_rows, metrics: None }. The new + // `AffectedRows { rows, metrics: Option }` variant with + // `metrics: None` produces the exact same wire bytes. + let old_wire_bytes = FlightData { + flight_descriptor: None, + data_header: build_none_flight_msg().into(), + app_metadata: FlightMetadata { + affected_rows: Some(AffectedRows { value: 99 }), + metrics: None, // old format: no metrics field + } + .encode_to_vec() + .into(), + data_body: ProstBytes::default(), + }; + + let mut decoder = FlightDecoder::default(); + let decoded = decoder.try_decode(&old_wire_bytes).unwrap().unwrap(); + assert!(matches!( + decoded, + FlightMessage::AffectedRows { + rows: 99, + metrics: None, + } + )); + } } diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs index aa3ffbfe3a1c..f16c83a84bf1 100644 --- a/src/datanode/src/region_server.rs +++ b/src/datanode/src/region_server.rs @@ -17,8 +17,10 @@ mod catalog; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Deref; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; use std::time::Duration; use api::region::RegionResponse; @@ -36,7 +38,8 @@ use common_error::status_code::StatusCode; use common_meta::datanode::TopicStatsReporter; use common_query::OutputData; use common_query::request::QueryRequest; -use common_recordbatch::SendableRecordBatchStream; +use common_recordbatch::adapter::RecordBatchMetrics; +use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream}; use common_runtime::Runtime; use common_telemetry::tracing::{self, info_span}; use common_telemetry::tracing_context::{FutureExt, TracingContext}; @@ -45,6 +48,7 @@ use dashmap::DashMap; use datafusion::datasource::TableProvider; use datafusion_common::tree_node::TreeNode; use either::Either; +use futures_util::Stream; use futures_util::future::try_join_all; use metric_engine::engine::MetricEngine; use mito2::engine::{MITO_ENGINE_NAME, MitoEngine}; @@ -53,6 +57,7 @@ use query::QueryEngineRef; pub use query::dummy_catalog::{ DummyCatalogList, DummyTableProviderFactory, TableProviderFactoryRef, }; +use query::options::should_collect_region_watermark_from_extensions; use serde_json; use servers::error::{ self as servers_error, ExecuteGrpcRequestSnafu, Result as ServerResult, SuspendedSnafu, @@ -278,16 +283,31 @@ impl RegionServer { .await .context(DecodeLogicalPlanSnafu)?; - self.inner + let stream = self + .inner .handle_read( QueryRequest { header: request.header, region_id, plan, }, - query_ctx, + query_ctx.clone(), ) - .await + .await?; + + let region_latest_seq = + if should_collect_region_watermark_from_extensions(&query_ctx.extensions()) { + query_ctx.get_snapshot(region_id.as_u64()) + } else { + None + }; + + if let Some(seq) = region_latest_seq { + Ok(Box::pin(RegionWatermarkStream::new(stream, region_id, seq)) + as SendableRecordBatchStream) + } else { + Ok(stream) + } } #[tracing::instrument(skip_all)] @@ -749,6 +769,80 @@ impl RegionServer { } } +/// Wraps a region read stream so terminal metrics can carry the scan-open watermark. +struct RegionWatermarkStream { + stream: SendableRecordBatchStream, + region_id: u64, + snapshot_seq: u64, + finished: bool, +} + +impl RegionWatermarkStream { + fn new(stream: SendableRecordBatchStream, region_id: RegionId, snapshot_seq: u64) -> Self { + Self { + stream, + region_id: region_id.as_u64(), + snapshot_seq, + finished: false, + } + } + + fn merged_metrics(&self, mut metrics: RecordBatchMetrics) -> RecordBatchMetrics { + if metrics + .region_watermarks + .iter() + .any(|entry| entry.region_id == self.region_id) + { + return metrics; + } + + metrics + .region_watermarks + .push(common_recordbatch::adapter::RegionWatermarkEntry { + region_id: self.region_id, + watermark: Some(self.snapshot_seq), + }); + metrics + } +} + +impl RecordBatchStream for RegionWatermarkStream { + fn name(&self) -> &str { + self.stream.name() + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.stream.schema() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.stream.output_ordering() + } + + fn metrics(&self) -> Option { + let base = self.stream.metrics(); + if !self.finished { + return base; + } + + Some(self.merged_metrics(base.unwrap_or_default())) + } +} + +impl Stream for RegionWatermarkStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(None) => { + self.finished = true; + Poll::Ready(None) + } + other => other, + } + } +} + #[async_trait] impl RegionServerHandler for RegionServer { async fn handle(&self, request: region_request::Body) -> ServerResult { @@ -1669,10 +1763,16 @@ impl RegionAttribute { mod tests { use std::assert_matches; + use std::sync::Arc; use api::v1::SemanticType; use common_error::ext::ErrorExt; - use datatypes::prelude::ConcreteDataType; + use common_recordbatch::RecordBatches; + use common_recordbatch::adapter::{RecordBatchMetrics, RegionWatermarkEntry}; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures_util::StreamExt; use mito2::test_util::CreateRequestBuilder; use store_api::metadata::{ColumnMetadata, RegionMetadata, RegionMetadataBuilder}; use store_api::region_engine::RegionEngine; @@ -1685,6 +1785,69 @@ mod tests { use crate::error::Result; use crate::tests::{MockRegionEngine, mock_region_server}; + #[tokio::test] + async fn test_region_watermark_stream_only_sets_terminal_metrics() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2])); + let batch = RecordBatch::new(schema.clone(), vec![values]).unwrap(); + let stream = RecordBatches::try_new(schema, vec![batch]) + .unwrap() + .as_stream(); + + let region_id = RegionId::new(42, 7); + let wrapped = RegionWatermarkStream::new(stream, region_id, 99); + let mut pinned = Box::pin(wrapped); + + assert!(pinned.as_ref().get_ref().metrics().is_none()); + while pinned.next().await.is_some() {} + + let metrics = pinned.as_ref().get_ref().metrics().unwrap(); + assert_eq!( + metrics.region_watermarks, + vec![RegionWatermarkEntry { + region_id: region_id.as_u64(), + watermark: Some(99), + }] + ); + } + + #[test] + fn test_region_watermark_stream_preserves_unproved_watermark() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let values: VectorRef = Arc::new(Int32Vector::from_slice([1])); + let batch = RecordBatch::new(schema.clone(), vec![values]).unwrap(); + let stream = RecordBatches::try_new(schema, vec![batch]) + .unwrap() + .as_stream(); + + let region_id = RegionId::new(42, 7); + let wrapped = RegionWatermarkStream::new(stream, region_id, 99); + let metrics = RecordBatchMetrics { + region_watermarks: vec![RegionWatermarkEntry { + region_id: region_id.as_u64(), + watermark: None, + }], + ..Default::default() + }; + + let merged = wrapped.merged_metrics(metrics); + assert_eq!( + merged.region_watermarks, + vec![RegionWatermarkEntry { + region_id: region_id.as_u64(), + watermark: None, + }] + ); + } + #[tokio::test] async fn test_region_registering() { common_telemetry::init_default_ut_logging(); diff --git a/src/flow/src/batching_mode/frontend_client.rs b/src/flow/src/batching_mode/frontend_client.rs index 9875564c78ae..c29c52846bbc 100644 --- a/src/flow/src/batching_mode/frontend_client.rs +++ b/src/flow/src/batching_mode/frontend_client.rs @@ -196,6 +196,9 @@ impl DatabaseWithPeer { } impl FrontendClient { + // TODO: support more fine-grained load balancing strategies for frontend + // selection, such as AZ (availability zone) awareness, to prefer frontends + // in the same zone as the flownode and reduce cross-AZ latency. /// scan for available frontend from metadata pub(crate) async fn scan_for_frontend(&self) -> Result, Error> { let Self::Distributed { meta_client, .. } = self else { @@ -314,12 +317,7 @@ impl FrontendClient { database_client .handler .lock() - .map_err(|e| { - UnexpectedSnafu { - reason: format!("Failed to lock database client: {e}"), - } - .build() - })? + .unwrap() .as_ref() .context(UnexpectedSnafu { reason: "Standalone's frontend instance is not set", @@ -392,12 +390,7 @@ impl FrontendClient { database_client .handler .lock() - .map_err(|e| { - UnexpectedSnafu { - reason: format!("Failed to lock database client: {e}"), - } - .build() - })? + .unwrap() .as_ref() .context(UnexpectedSnafu { reason: "Standalone's frontend instance is not set", diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index d148e6aa1bf6..6f78d23e14ee 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -399,7 +399,7 @@ impl ErrorExt for Error { Error::PrometheusLabelValuesQueryPlan { source, .. } => source.status_code(), - Error::CollectRecordbatch { .. } => StatusCode::EngineExecuteQuery, + Error::CollectRecordbatch { source, .. } => source.status_code(), Error::SqlExecIntercepted { source, .. } => source.status_code(), Error::StartServer { source, .. } => source.status_code(), diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 6fc78c59e5c6..5725cf493854 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -59,7 +59,10 @@ use crate::error::{ TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu, }; use crate::executor::QueryExecutor; -use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED}; +use crate::metrics::{ + OnDone, QUERY_STAGE_ELAPSED, maybe_attach_region_watermark_metrics, + should_collect_region_watermark_from_query_ctx, +}; use crate::physical_wrapper::PhysicalPlanWrapperRef; use crate::planner::{DfLogicalPlanner, LogicalPlanner}; use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState}; @@ -100,8 +103,10 @@ impl DatafusionQueryEngine { optimized_physical_plan }; + let stream = self.execute_stream(&ctx, &physical_plan)?; + Ok(Output::new( - OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?), + OutputData::Stream(stream), OutputMeta::new_with_plan(physical_plan), )) } @@ -128,10 +133,10 @@ impl DatafusionQueryEngine { let table_name = dml.table_name.resolve(default_catalog, default_schema); let table = self.find_table(&table_name, &query_ctx).await?; - let output = self + let Output { data, meta } = self .exec_query_plan((*dml.input).clone(), query_ctx.clone()) .await?; - let mut stream = match output.data { + let mut stream = match data { OutputData::RecordBatches(batches) => batches.as_stream(), OutputData::Stream(stream) => stream, _ => unreachable!(), @@ -167,7 +172,7 @@ impl DatafusionQueryEngine { } Ok(Output::new( OutputData::AffectedRows(affected_rows), - OutputMeta::new_with_cost(insert_cost), + OutputMeta::new(meta.plan, insert_cost), )) } @@ -543,7 +548,10 @@ impl QueryExecutor for DatafusionQueryEngine { ctx: &QueryEngineContext, plan: &Arc, ) -> Result { - let explain_verbose = ctx.query_ctx().explain_verbose(); + let query_ctx = ctx.query_ctx(); + let explain_verbose = query_ctx.explain_verbose(); + let should_collect_region_watermark = + should_collect_region_watermark_from_query_ctx(&query_ctx)?; let output_partitions = plan.properties().output_partitioning().partition_count(); if explain_verbose { common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}"); @@ -579,7 +587,11 @@ impl QueryExecutor for DatafusionQueryEngine { ); } }); - Ok(Box::pin(stream)) + Ok(maybe_attach_region_watermark_metrics( + Box::pin(stream), + plan.clone(), + should_collect_region_watermark, + )) } _ => { // merge into a single partition @@ -598,7 +610,7 @@ impl QueryExecutor for DatafusionQueryEngine { .map_err(BoxedError::new) .context(QueryExecutionSnafu)?; stream.set_metrics2(plan.clone()); - stream.set_explain_verbose(ctx.query_ctx().explain_verbose()); + stream.set_explain_verbose(explain_verbose); let stream = OnDone::new(Box::pin(stream), move || { let exec_cost = exec_timer.stop_and_record(); if explain_verbose { @@ -608,7 +620,11 @@ impl QueryExecutor for DatafusionQueryEngine { ); } }); - Ok(Box::pin(stream)) + Ok(maybe_attach_region_watermark_metrics( + Box::pin(stream), + plan.clone(), + should_collect_region_watermark, + )) } } } diff --git a/src/query/src/metrics.rs b/src/query/src/metrics.rs index 9a376d748cd8..7541b191faf4 100644 --- a/src/query/src/metrics.rs +++ b/src/query/src/metrics.rs @@ -26,8 +26,11 @@ use futures::Stream; use futures_util::ready; use lazy_static::lazy_static; use prometheus::*; +use session::context::QueryContextRef; use crate::dist_plan::MergeScanExec; +use crate::error::Result; +use crate::options::FlowQueryExtensions; /// Intermediate merge state for one participating region while collecting /// terminal correctness watermarks across merge-scan sub-stages. @@ -201,6 +204,27 @@ impl Stream for RegionWatermarkMetricsStream { } } +/// Returns whether terminal region watermark metrics should be collected for the query context. +pub fn should_collect_region_watermark_from_query_ctx(query_ctx: &QueryContextRef) -> Result { + Ok( + FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())? + .is_some_and(|extensions| extensions.should_collect_region_watermark()), + ) +} + +/// Attaches terminal region watermark metrics to `stream` when collection is requested. +pub fn maybe_attach_region_watermark_metrics( + stream: SendableRecordBatchStream, + plan: Arc, + should_collect_region_watermark: bool, +) -> SendableRecordBatchStream { + if should_collect_region_watermark { + Box::pin(RegionWatermarkMetricsStream::new(stream, plan)) + } else { + stream + } +} + pub fn terminal_recordbatch_metrics_from_plan( plan: Arc, ) -> Option { @@ -215,6 +239,18 @@ pub fn terminal_recordbatch_metrics_from_plan( } } +/// Collects terminal record-batch metrics from `plan` only when requested. +pub fn terminal_recordbatch_metrics_from_plan_if_requested( + plan: Option>, + should_collect_region_watermark: bool, +) -> Option { + if should_collect_region_watermark { + plan.and_then(terminal_recordbatch_metrics_from_plan) + } else { + None + } +} + fn collect_region_watermarks(plan: Arc) -> Vec { let mut merged = BTreeMap::::new(); let mut stack = vec![plan]; @@ -230,57 +266,85 @@ fn collect_region_watermarks(plan: Arc) -> Vec, + entries: impl IntoIterator, +) { + for entry in entries { + merged + .entry(entry.region_id) + .and_modify(|existing| match entry.watermark { + None => match existing { + MergeState::Participated | MergeState::Proved(_) => { + *existing = MergeState::Unproved; + } + MergeState::Unproved | MergeState::Conflict { .. } => {} + }, + Some(seq) => match existing { + MergeState::Participated => { + *existing = MergeState::Proved(seq); + } + MergeState::Unproved => {} + MergeState::Proved(existing_seq) if *existing_seq == seq => {} + MergeState::Proved(existing_seq) => { + let old_seq = *existing_seq; + *existing = MergeState::Conflict { + watermarks: vec![old_seq, seq], + }; + } + MergeState::Conflict { watermarks } => { + if !watermarks.contains(&seq) { + watermarks.push(seq); + } + } + }, + }) + .or_insert(match entry.watermark { + Some(seq) => MergeState::Proved(seq), + None => MergeState::Unproved, + }); + } +} + fn merge_merge_scan_region_watermarks( merged: &mut BTreeMap, regions: impl IntoIterator, sub_stage_metrics: impl IntoIterator, ) { + // Regions listed by MergeScanExec participated even when no sub-stage can + // prove a watermark. Keep them as explicit `None` entries so callers can + // distinguish unproved participation from non-participation. for region_id in regions { merged.entry(region_id).or_insert(MergeState::Participated); } for metrics in sub_stage_metrics { - for entry in metrics.region_watermarks { - merged - .entry(entry.region_id) - .and_modify(|existing| match entry.watermark { - None => match existing { - MergeState::Participated | MergeState::Proved(_) => { - *existing = MergeState::Unproved; - } - MergeState::Unproved | MergeState::Conflict { .. } => {} - }, - Some(seq) => match existing { - MergeState::Participated => { - *existing = MergeState::Proved(seq); - } - MergeState::Unproved => {} - MergeState::Proved(existing_seq) if *existing_seq == seq => {} - MergeState::Proved(existing_seq) => { - let old_seq = *existing_seq; - *existing = MergeState::Conflict { - watermarks: vec![old_seq, seq], - }; - } - MergeState::Conflict { watermarks } => { - if !watermarks.contains(&seq) { - watermarks.push(seq); - } - } - }, - }) - .or_insert(match entry.watermark { - Some(seq) => MergeState::Proved(seq), - None => MergeState::Unproved, - }); - } + merge_region_watermark_entries(merged, metrics.region_watermarks); } } diff --git a/src/query/src/options.rs b/src/query/src/options.rs index 46b8f1e413cf..688d6315b0a0 100644 --- a/src/query/src/options.rs +++ b/src/query/src/options.rs @@ -177,10 +177,35 @@ impl FlowQueryExtensions { } pub fn should_collect_region_watermark(&self) -> bool { - self.return_region_seq || self.incremental_after_seqs.is_some() + should_collect_region_watermark( + self.return_region_seq, + self.incremental_after_seqs.is_some(), + ) } } +/// Returns whether raw Flow query extensions request terminal region watermark collection. +/// +/// This is only an intent/presence check for transport/scan plumbing; callers that need +/// validated Flow options must still use [`FlowQueryExtensions::parse_flow_extensions`]. +pub fn should_collect_region_watermark_from_extensions( + extensions: &HashMap, +) -> bool { + let return_region_seq = extensions + .get(FLOW_RETURN_REGION_SEQ) + .is_some_and(|value| value.eq_ignore_ascii_case("true")); + let has_incremental_after_seqs = extensions.contains_key(FLOW_INCREMENTAL_AFTER_SEQS); + + should_collect_region_watermark(return_region_seq, has_incremental_after_seqs) +} + +fn should_collect_region_watermark( + return_region_seq: bool, + has_incremental_after_seqs: bool, +) -> bool { + return_region_seq || has_incremental_after_seqs +} + fn parse_incremental_after_seqs(value: &str) -> Result> { let raw = serde_json::from_str::>(value).map_err(|e| { invalid_query_context_extension(format!( @@ -420,6 +445,24 @@ mod flow_extension_tests { assert!(parsed.should_collect_region_watermark()); } + #[test] + fn test_should_collect_region_watermark_from_extensions() { + let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "true".to_string())]); + assert!(should_collect_region_watermark_from_extensions(&exts)); + + let exts = HashMap::from([( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + r#"{"1":10}"#.to_string(), + )]); + assert!(should_collect_region_watermark_from_extensions(&exts)); + + let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "false".to_string())]); + assert!(!should_collect_region_watermark_from_extensions(&exts)); + assert!(!should_collect_region_watermark_from_extensions( + &HashMap::new() + )); + } + #[test] fn test_parse_flow_extensions_return_region_seq_only_returns_some() { let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "true".to_string())]); diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 364ce8ce268a..a2ff787d1b14 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -26,8 +26,11 @@ use arrow_flight::{ }; use async_trait::async_trait; use bytes::{self, Bytes}; +use common_error::ext::ErrorExt; use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse}; -use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage}; +use common_grpc::flight::{ + FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightEncoder, FlightMessage, +}; use common_memory_manager::MemoryGuard; use common_query::{Output, OutputData}; use common_recordbatch::DfRecordBatch; @@ -38,7 +41,9 @@ use datatypes::arrow::datatypes::SchemaRef; use futures::{Stream, future, ready}; use futures_util::{StreamExt, TryStreamExt}; use prost::Message; -use session::context::{QueryContext, QueryContextRef}; +use query::metrics::terminal_recordbatch_metrics_from_plan_if_requested; +use query::options::FlowQueryExtensions; +use session::context::{Channel, QueryContextRef}; use snafu::{IntoError, ResultExt, ensure}; use table::table_name::TableName; use tokio::sync::mpsc; @@ -47,7 +52,9 @@ use tonic::{Request, Response, Status, Streaming}; use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; -use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type}; +use crate::grpc::greptime_handler::{ + GreptimeRequestHandler, create_query_context, get_request_type, +}; use crate::grpc::{FlightCompression, TonicResult, context_auth}; use crate::request_memory_limiter::ServerMemoryLimiter; use crate::request_memory_metrics::RequestMemoryMetrics; @@ -186,11 +193,22 @@ impl FlightCraft for GreptimeRequestHandler { &self, request: Request, ) -> TonicResult>> { - let hints = hint_headers::extract_hints(request.metadata()); + let mut hints = hint_headers::extract_hints(request.metadata()); + hints.extend(extract_flow_extensions(request.metadata())?); let ticket = request.into_inner().ticket; let request = GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; + let query_ctx = + create_query_context(Channel::Grpc, request.header.as_ref(), hints.clone())?; + // Validate flow hint syntax at the transport boundary before dispatching the request. + // This does not authorize or execute anything; `handle_request()` below still performs + // the normal frontend handling and auth checks before query execution. + let flow_extensions = FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions()) + .map_err(|e| Status::invalid_argument(e.output_msg()))?; + let should_emit_terminal_metrics = flow_extensions + .as_ref() + .is_some_and(|extensions| extensions.should_collect_region_watermark()); // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream let span = info_span!( @@ -205,7 +223,8 @@ impl FlightCraft for GreptimeRequestHandler { output, TracingContext::from_current_span(), flight_compression, - QueryContext::arc(), + query_ctx, + should_emit_terminal_metrics, ); Ok(Response::new(stream)) } @@ -512,11 +531,32 @@ impl Stream for PutRecordBatchRequestStream { } } +fn extract_flow_extensions( + metadata: &tonic::metadata::MetadataMap, +) -> TonicResult> { + let Some(value) = metadata.get(FLOW_EXTENSIONS_METADATA_KEY) else { + return Ok(vec![]); + }; + + let value = value.to_str().map_err(|e| { + Status::invalid_argument(format!( + "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata value: {e}" + )) + })?; + + serde_json::from_str::>(value).map_err(|e| { + Status::invalid_argument(format!( + "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata JSON: {e}" + )) + }) +} + fn to_flight_data_stream( output: Output, tracing_context: TracingContext, flight_compression: FlightCompression, query_ctx: QueryContextRef, + should_emit_terminal_metrics: bool, ) -> TonicStream { match output.data { OutputData::Stream(stream) => { @@ -538,13 +578,60 @@ fn to_flight_data_stream( Box::pin(stream) as _ } OutputData::AffectedRows(rows) => { - let stream = tokio_stream::iter( - FlightEncoder::default() - .encode(FlightMessage::AffectedRows(rows)) - .into_iter() - .map(Ok), - ); + let terminal_metrics = terminal_recordbatch_metrics_from_plan_if_requested( + output.meta.plan, + should_emit_terminal_metrics, + ) + .and_then(|metrics| serde_json::to_string(&metrics).ok()); + let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows { + rows, + metrics: terminal_metrics, + }); + let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok)); Box::pin(stream) as _ } } } + +#[cfg(test)] +mod tests { + use tonic::metadata::{AsciiMetadataValue, MetadataMap}; + + use super::*; + + #[test] + fn test_extract_flow_extensions_preserves_comma_bearing_values() { + let mut metadata = MetadataMap::new(); + metadata.insert( + FLOW_EXTENSIONS_METADATA_KEY, + AsciiMetadataValue::try_from( + r#"[["flow.return_region_seq","true"],["flow.incremental_after_seqs","{\"1\":10,\"2\":20}"]]"#, + ) + .unwrap(), + ); + + let extensions = extract_flow_extensions(&metadata).unwrap(); + assert_eq!( + extensions, + vec![ + ("flow.return_region_seq".to_string(), "true".to_string()), + ( + "flow.incremental_after_seqs".to_string(), + r#"{"1":10,"2":20}"#.to_string() + ), + ] + ); + } + + #[test] + fn test_extract_flow_extensions_rejects_invalid_json() { + let mut metadata = MetadataMap::new(); + metadata.insert( + FLOW_EXTENSIONS_METADATA_KEY, + AsciiMetadataValue::try_from("not-json").unwrap(), + ); + + let err = extract_flow_extensions(&metadata).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } +} diff --git a/tests-integration/src/grpc/flight.rs b/tests-integration/src/grpc/flight.rs index 9ed6b8176f1f..d638926fff3a 100644 --- a/tests-integration/src/grpc/flight.rs +++ b/tests-integration/src/grpc/flight.rs @@ -27,6 +27,7 @@ mod test { use common_grpc::flight::{FlightEncoder, FlightMessage}; use common_query::OutputData; use common_recordbatch::RecordBatch; + use common_recordbatch::adapter::RegionWatermarkEntry; use datatypes::prelude::{ConcreteDataType, ScalarVector, VectorRef}; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{Int32Vector, StringVector, TimestampMillisecondVector}; @@ -129,6 +130,104 @@ mod test { | 1970-01-01T00:00:00.009 | -9 | s9 | +-------------------------+----+----+"; query_and_expect(db.fe_instance().as_ref(), sql, expected).await; + + let output = client.sql(sql).await.unwrap(); + let OutputData::Stream(mut stream) = output.data else { + panic!("expected stream output"); + }; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } + let metrics = stream.metrics().expect("expected terminal metrics"); + assert!(metrics.region_watermarks.is_empty()); + + let result = client + .sql_with_terminal_metrics(sql, &[("flow.return_region_seq", "true")]) + .await + .unwrap(); + let terminal_metrics = result.metrics.clone(); + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } + assert!(terminal_metrics.is_ready()); + let regions = db.list_all_regions().await; + assert_eq!(regions.len(), 1); + let (region_id, region) = regions.into_iter().next().unwrap(); + let expected_watermark = (region_id.as_u64(), region.find_committed_sequence()); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(std::collections::HashMap::from([expected_watermark])) + ); + + let output = client + .sql_with_hint(sql, &[("flow.return_region_seq", "true")]) + .await + .unwrap(); + let OutputData::Stream(mut stream) = output.data else { + panic!("expected stream output"); + }; + + let mut row_count = 0; + while let Some(batch) = stream.next().await { + let batch = batch.unwrap(); + row_count += batch.num_rows(); + } + assert_eq!(row_count, 9); + + let metrics = stream.metrics().expect("expected terminal metrics"); + let region_watermarks = metrics.region_watermarks; + assert_eq!( + region_watermarks, + vec![RegionWatermarkEntry { + region_id: expected_watermark.0, + watermark: Some(expected_watermark.1), + }] + ); + + let previous_watermark = expected_watermark; + + create_table_named(&client, "bar").await; + let result = client + .sql_with_terminal_metrics("insert into bar select ts, a, `B` from foo", &[]) + .await + .unwrap(); + let OutputData::AffectedRows(affected_rows) = result.output.data else { + panic!("expected affected rows output"); + }; + assert_eq!(affected_rows, 9); + assert!(result.metrics.is_ready()); + assert!(result.region_watermark_map().is_none()); + + let err = client + .sql_with_terminal_metrics( + "insert into bar select ts, a, `B` from foo", + &[("flow.return_region_seq", "not-a-bool")], + ) + .await + .unwrap_err(); + let err_msg = format!("{err:?}"); + assert!(err_msg.contains("Invalid value for flow.return_region_seq")); + + client.sql("truncate table bar").await.unwrap(); + + let result = client + .sql_with_terminal_metrics( + "insert into bar select ts, a, `B` from foo", + &[("flow.return_region_seq", "true")], + ) + .await + .unwrap(); + let OutputData::AffectedRows(affected_rows) = result.output.data else { + panic!("expected affected rows output"); + }; + assert_eq!(affected_rows, 9); + assert_eq!( + result.region_watermark_map(), + Some(std::collections::HashMap::from([previous_watermark])) + ); } async fn test_put_record_batches(client: &Database, record_batches: Vec) { @@ -224,6 +323,10 @@ mod test { } async fn create_table(client: &Database) { + create_table_named(client, "foo").await; + } + + async fn create_table_named(client: &Database, table_name: &str) { // create table foo ( // ts timestamp time index, // a int primary key, @@ -232,7 +335,7 @@ mod test { let output = client .create(CreateTableExpr { schema_name: "public".to_string(), - table_name: "foo".to_string(), + table_name: table_name.to_string(), column_defs: vec![ ColumnDef { name: "ts".to_string(),