Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
363 changes: 363 additions & 0 deletions docs/rfcs/2026-04-08-aggregate-stats-physical-pass.md

Large diffs are not rendered by default.

160 changes: 88 additions & 72 deletions src/common/function/src/aggrs/aggr_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,39 @@ impl StateWrapper {
acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
Ok(acc_args)
}

/// Assembles a struct-typed state scalar out of caller-supplied field values.
///
/// `state_values` must contain exactly one scalar for every field of the struct type
/// returned by `return_type(arg_types)`, in the same order. This method checks that the
/// wrapper's return type is a struct and that the number of values matches the number
/// of state fields; the final struct is assembled by `build_state_struct_array`.
///
/// # Errors
///
/// Returns an internal error when the return type is not a struct or when the number of
/// supplied values differs from the number of state fields. Failures from converting a
/// scalar to an array or from building the struct array are propagated unchanged.
pub fn value_from_custom_state_fields(
    &self,
    arg_types: &[DataType],
    state_values: Vec<ScalarValue>,
) -> datafusion_common::Result<ScalarValue> {
    let state_fields = match self.return_type(arg_types)? {
        DataType::Struct(state_fields) => state_fields,
        _ => {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected struct state type for {}, got non-struct return type",
                self.name()
            )));
        }
    };

    let expected_len = state_fields.len();
    let actual_len = state_values.len();
    if expected_len != actual_len {
        return Err(datafusion_common::DataFusionError::Internal(format!(
            "Expected {} state fields for {}, got {}",
            expected_len,
            self.name(),
            actual_len
        )));
    }

    // Turn every scalar into a one-element column, stopping at the first failure.
    let mut columns = Vec::with_capacity(actual_len);
    for value in state_values {
        columns.push(value.to_array()?);
    }

    build_state_struct_array(&state_fields, columns)
        .map(|struct_array| ScalarValue::Struct(Arc::new(struct_array)))
}
}

impl AggregateUDFImpl for StateWrapper {
Expand Down Expand Up @@ -472,13 +505,59 @@ impl AggregateUDFImpl for StateWrapper {
};

let array = ret.to_array().ok()?;

let struct_array = StructArray::new(fields.clone(), vec![array], None);
let ret = ScalarValue::Struct(Arc::new(struct_array));
Some(ret)
}
}

/// Packs state columns into a [`StructArray`].
///
/// When the data types of `arrays` equal, position by position, the data types of
/// `fields`, the struct is built with `fields` as its schema. Otherwise a placeholder
/// schema is derived from the arrays themselves (see the comment in the fallback branch).
///
/// # Errors
///
/// Any error reported by `StructArray::try_new` is wrapped in
/// `DataFusionError::ArrowError`.
fn build_state_struct_array(
    fields: &Fields,
    arrays: Vec<ArrayRef>,
) -> datafusion_common::Result<StructArray> {
    // Element-wise comparison that also fails when the two sides differ in length.
    let types_match = arrays
        .iter()
        .map(|array| array.data_type())
        .eq(fields.iter().map(|field| field.data_type()));

    let schema = if types_match {
        fields.clone()
    } else {
        // Deliberately lenient fallback.
        //
        // The wrapper path has historically accepted drift between the declared state
        // schema and the actual state columns, provided the columns stay positionally
        // compatible. Order-sensitive aggregates such as first_value/last_value are the
        // clearest example: the state metadata reported by DataFusion and the arrays we
        // have to wrap may not agree exactly. Since the merge path reads state columns by
        // position rather than by field metadata, keeping a struct wrapper is more
        // compatible than rejecting the mismatch up front.
        let actual_types = arrays
            .iter()
            .map(|array| array.data_type().clone())
            .collect::<Vec<_>>();
        debug!(
            "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
            fields.len(),
            arrays.len(),
            fields,
            actual_types,
        );
        arrays
            .iter()
            .enumerate()
            .map(|(index, array)| {
                Field::new(
                    format!("col_{index}[mismatch_state]"),
                    array.data_type().clone(),
                    true,
                )
            })
            .collect::<Fields>()
    };

    StructArray::try_new(schema, arrays, None)
        .map_err(|err| datafusion_common::DataFusionError::ArrowError(Box::new(err), None))
}

/// The wrapper's input is the same as the original aggregate function's input,
/// and the output is the state function's output.
#[derive(Debug)]
Expand Down Expand Up @@ -510,42 +589,9 @@ impl StateGroupsAccum {
}

/// Wraps the given per-field state arrays into one struct array and returns it as an
/// `ArrayRef`.
///
/// NOTE(review): this span looks like a diff rendered without +/- markers. The inline
/// mismatch handling and the `StructArray::try_new(...)` tail appear to be the removed
/// lines, and the call to `build_state_struct_array(&self.state_fields, arrays)` the
/// added ones — confirm against the real file before editing.
fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
// Data types of the arrays actually received, in order.
let array_type = arrays
.iter()
.map(|array| array.data_type().clone())
.collect::<Vec<_>>();
// Data types declared by the expected state fields, in order.
let expected_type = self
.state_fields
.iter()
.map(|field| field.data_type().clone())
.collect::<Vec<_>>();
if array_type != expected_type {
debug!(
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
self.state_fields.len(),
arrays.len(),
self.state_fields,
array_type,
);
// On mismatch, derive a placeholder schema from the arrays themselves: one nullable
// field per array, named by its position.
let guess_schema = arrays
.iter()
.enumerate()
.map(|(index, array)| {
Field::new(
format!("col_{index}[mismatch_state]").as_str(),
array.data_type().clone(),
true,
)
})
.collect::<Fields>();
let array = StructArray::try_new(guess_schema, arrays, None)?;
return Ok(Arc::new(array));
}

Ok(Arc::new(StructArray::try_new(
self.state_fields.clone(),
Ok(Arc::new(build_state_struct_array(
&self.state_fields,
arrays,
None,
)?))
}
}
Expand Down Expand Up @@ -621,44 +667,11 @@ impl Accumulator for StateAccum {
/// Evaluates the accumulator by taking the inner accumulator's state, converting each
/// state scalar to an array, and returning the arrays wrapped as a struct scalar.
///
/// NOTE(review): this span looks like a diff rendered without +/- markers. The
/// `let array = state` binding and the inline mismatch handling appear to be the removed
/// lines, while `let arrays = state` and the `build_state_struct_array` call appear to be
/// the added ones — confirm against the real file before editing.
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
let state = self.inner.state()?;

let array = state
let arrays = state
.iter()
.map(|s| s.to_array())
.collect::<Result<Vec<_>, _>>()?;
// Data types of the arrays produced from the inner state, in order.
let array_type = array
.iter()
.map(|a| a.data_type().clone())
.collect::<Vec<_>>();
// Data types declared by the expected state fields, in order.
let expected_type: Vec<_> = self
.state_fields
.iter()
.map(|f| f.data_type().clone())
.collect();
if array_type != expected_type {
debug!(
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
self.state_fields.len(),
array.len(),
self.state_fields,
array_type,
);
// On mismatch, derive a placeholder schema from the arrays themselves: one nullable
// field per array, named by its position.
let guess_schema = array
.iter()
.enumerate()
.map(|(index, array)| {
Field::new(
format!("col_{index}[mismatch_state]").as_str(),
array.data_type().clone(),
true,
)
})
.collect::<Fields>();
let arr = StructArray::try_new(guess_schema, array, None)?;

return Ok(ScalarValue::Struct(Arc::new(arr)));
}

let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
let struct_array = build_state_struct_array(&self.state_fields, arrays)?;
Ok(ScalarValue::Struct(Arc::new(struct_array)))
}

Expand Down Expand Up @@ -860,7 +873,10 @@ impl Accumulator for MergeAccum {
"State fields mismatch, expected: {:?}, got: {:?}",
self.state_fields, fields
);
// state fields mismatch might be acceptable by datafusion, continue
// Intentionally continue here for compatibility with the wrapper's historical
// behavior: downstream merge logic uses the struct columns positionally, and some
// DataFusion/order-sensitive aggregate paths can produce equivalent state payloads
// whose field metadata does not exactly match our locally expected schema.
}

// now fields should be the same, so we can merge the batch
Expand Down
45 changes: 45 additions & 0 deletions src/common/function/src/aggrs/aggr_wrapper/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion::datasource::DefaultTableSource;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::min_max::max_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
Expand Down Expand Up @@ -291,6 +292,50 @@ fn create_avg_state_groups_accumulator() -> Box<dyn GroupsAccumulator> {
state_wrapper.create_groups_accumulator(acc_args).unwrap()
}

/// Returns a fixed, non-null sample scalar for the given state field type.
///
/// Only `Float64`, `UInt64` and `Int64` are supported; any other type panics.
fn test_state_scalar_for_type(data_type: &DataType) -> ScalarValue {
    if matches!(data_type, DataType::Float64) {
        return ScalarValue::Float64(Some(1.5));
    }
    if matches!(data_type, DataType::UInt64) {
        return ScalarValue::UInt64(Some(2));
    }
    if matches!(data_type, DataType::Int64) {
        return ScalarValue::Int64(Some(3));
    }
    panic!("unsupported test data type: {data_type:?}")
}

/// A single explicit `Int64` value for the `max` state wrapper yields a struct scalar
/// with exactly one `Int64` column.
#[test]
fn test_value_from_custom_state_fields_single_field() {
    let max_state = StateWrapper::new((*max_udaf()).clone()).unwrap();
    let input_types = [DataType::Int64];
    let field_values = vec![ScalarValue::Int64(Some(7))];

    let state = max_state
        .value_from_custom_state_fields(&input_types, field_values)
        .unwrap();

    let struct_array = match state {
        ScalarValue::Struct(struct_array) => struct_array,
        _ => panic!("expected struct state"),
    };
    assert_eq!(1, struct_array.columns().len());
    assert_eq!(&DataType::Int64, struct_array.column(0).data_type());
}

/// Supplying one sample value per state field of the `avg` state wrapper yields a struct
/// scalar with as many columns as the wrapper declares state fields.
#[test]
fn test_value_from_custom_state_fields_multi_field() {
    let avg_state = StateWrapper::new((*avg_udaf()).clone()).unwrap();
    let input_types = [DataType::Float64];
    let state_fields = match avg_state.return_type(&input_types).unwrap() {
        DataType::Struct(state_fields) => state_fields,
        _ => panic!("expected struct state type"),
    };

    // One sample scalar per declared state field, in declaration order.
    let mut field_values = Vec::with_capacity(state_fields.len());
    for field in state_fields.iter() {
        field_values.push(test_state_scalar_for_type(field.data_type()));
    }

    let state = avg_state
        .value_from_custom_state_fields(&input_types, field_values)
        .unwrap();

    let struct_array = match state {
        ScalarValue::Struct(struct_array) => struct_array,
        _ => panic!("expected struct state"),
    };
    assert_eq!(state_fields.len(), struct_array.columns().len());
}

#[tokio::test]
async fn test_sum_udaf() {
let ctx = SessionContext::new();
Expand Down
1 change: 1 addition & 0 deletions src/mito2/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub(crate) mod prune;
pub(crate) mod pruner;
pub mod range;
pub(crate) mod range_cache;
pub(crate) mod scan_input_stats;
pub mod scan_region;
pub mod scan_util;
pub(crate) mod seq_scan;
Expand Down
Loading
Loading