Skip to content
762 changes: 681 additions & 81 deletions src/client/src/database.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub use common_recordbatch::{RecordBatches, SendableRecordBatchStream};
use snafu::OptionExt;

pub use self::client::Client;
pub use self::database::Database;
pub use self::database::{Database, OutputMetrics, OutputWithMetrics};
pub use self::error::{Error, Result};
use crate::error::{IllegalDatabaseResponseSnafu, ServerSnafu};

Expand Down
134 changes: 130 additions & 4 deletions src/common/grpc/src/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ use vec1::{Vec1, vec1};
use crate::error;
use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};

/// Flight metadata key used to carry flow query extensions as JSON pairs.
pub const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions";

#[derive(Debug, Clone)]
pub enum FlightMessage {
Schema(SchemaRef),
RecordBatch(DfRecordBatch),
AffectedRows(usize),
AffectedRows {
Comment thread
killme2008 marked this conversation as resolved.
rows: usize,
metrics: Option<String>,
Comment thread
discord9 marked this conversation as resolved.
},
Metrics(String),
}

Expand Down Expand Up @@ -116,10 +122,12 @@ impl FlightEncoder {
encoded_batch.into(),
)
}
FlightMessage::AffectedRows(rows) => {
FlightMessage::AffectedRows { rows, metrics } => {
let metadata = FlightMetadata {
affected_rows: Some(AffectedRows { value: rows as _ }),
metrics: None,
metrics: metrics.map(|s| Metrics {
metrics: s.into_bytes(),
}),
}
.encode_to_vec();
vec1![FlightData {
Expand Down Expand Up @@ -223,7 +231,12 @@ impl FlightDecoder {
let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
.context(DecodeFlightDataSnafu)?;
if let Some(AffectedRows { value }) = metadata.affected_rows {
return Ok(Some(FlightMessage::AffectedRows(value as _)));
return Ok(Some(FlightMessage::AffectedRows {
rows: value as _,
metrics: metadata
.metrics
.map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
}));
}
if let Some(Metrics { metrics }) = metadata.metrics {
return Ok(Some(FlightMessage::Metrics(
Expand Down Expand Up @@ -426,6 +439,47 @@ mod test {
Ok(())
}

#[test]
fn test_affected_rows_metrics_encode_decode() -> Result<()> {
let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
let mut encoder = FlightEncoder::default();
let encoded = encoder.encode(FlightMessage::AffectedRows {
rows: 3,
metrics: Some(metrics.to_string()),
});

assert_eq!(encoded.len(), 1);

let mut decoder = FlightDecoder::default();
let decoded = decoder.try_decode(encoded.first())?.unwrap();
let FlightMessage::AffectedRows {
rows,
metrics: decoded_metrics,
} = decoded
else {
unreachable!()
};
assert_eq!(rows, 3);
assert_eq!(decoded_metrics.as_deref(), Some(metrics));

let encoded = encoder.encode(FlightMessage::AffectedRows {
rows: 5,
metrics: None,
});
let decoded = decoder.try_decode(encoded.first())?.unwrap();
let FlightMessage::AffectedRows {
rows,
metrics: decoded_metrics,
} = decoded
else {
unreachable!()
};
assert_eq!(rows, 5);
assert!(decoded_metrics.is_none());

Ok(())
}

#[test]
fn test_flight_messages_to_recordbatches() {
let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
Expand Down Expand Up @@ -548,4 +602,76 @@ mod test {
assert_eq!(actual, expected.trim());
Ok(())
}

#[test]
fn test_affected_rows_roundtrip_through_flight_codec() {
// Verify the full FlightEncoder → FlightDecoder pipeline handles
// the new FlightMessage::AffectedRows variant with optional inline
// metrics without breaking the wire protocol.
let mut encoder = FlightEncoder::default();
let mut decoder = FlightDecoder::default();

// Without metrics — same wire format as old `AffectedRows(7)`.
let encoded = encoder.encode(FlightMessage::AffectedRows {
rows: 7,
metrics: None,
});
let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
assert!(matches!(
decoded,
FlightMessage::AffectedRows {
rows: 7,
metrics: None,
}
));

// With metrics — new capability, row count preserved.
let json = r#"{"region_watermarks":[{"region_id":1,"watermark":99}]}"#;
let encoded = encoder.encode(FlightMessage::AffectedRows {
rows: 42,
metrics: Some(json.to_string()),
});
let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
assert!(matches!(
decoded,
FlightMessage::AffectedRows {
rows: 42,
metrics: Some(_),
}
));
}

/// Simulates the wire output of the **old** `FlightMessage::AffectedRows(usize)`
/// variant and verifies that the **new** `FlightDecoder` handles it.
#[test]
fn test_old_affected_rows_format_decoded_by_new_code() {
use arrow_flight::FlightData;
use prost::bytes::Bytes as ProstBytes;

// The old encoder produced FlightData whose app_metadata is
// FlightMetadata { affected_rows, metrics: None }. The new
// `AffectedRows { rows, metrics: Option<String> }` variant with
// `metrics: None` produces the exact same wire bytes.
let old_wire_bytes = FlightData {
flight_descriptor: None,
data_header: build_none_flight_msg().into(),
app_metadata: FlightMetadata {
affected_rows: Some(AffectedRows { value: 99 }),
metrics: None, // old format: no metrics field
}
.encode_to_vec()
.into(),
data_body: ProstBytes::default(),
};

let mut decoder = FlightDecoder::default();
let decoded = decoder.try_decode(&old_wire_bytes).unwrap().unwrap();
assert!(matches!(
decoded,
FlightMessage::AffectedRows {
rows: 99,
metrics: None,
}
));
}
}
Loading
Loading