Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 242 additions & 30 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,28 @@ fn prepare_field_for_flight(
}
DataType::Dictionary(_, value_type) => {
if !send_dictionaries {
Field::new(
// Recurse into value type to handle nested dicts being stripped
let value_field = Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
);
prepare_field_for_flight(
&Arc::new(value_field),
dictionary_tracker,
send_dictionaries,
)
.with_metadata(field.metadata().clone())
} else {
// Recurse into value type BEFORE registering this dict's id,
// matching the depth-first order of encode_dictionaries in the
// IPC writer which processes nested dicts before the parent.
let value_field = Field::new("values", value_type.as_ref().clone(), true);
prepare_field_for_flight(
&Arc::new(value_field),
dictionary_tracker,
send_dictionaries,
);
dictionary_tracker.next_dict_id();
#[allow(deprecated)]
Field::new_dict(
Expand All @@ -547,6 +562,44 @@ fn prepare_field_for_flight(
.with_metadata(field.metadata().clone())
}
}
DataType::ListView(inner) | DataType::LargeListView(inner) => {
let prepared = prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries);
Field::new(
field.name(),
match field.data_type() {
DataType::ListView(_) => DataType::ListView(Arc::new(prepared)),
_ => DataType::LargeListView(Arc::new(prepared)),
},
field.is_nullable(),
)
.with_metadata(field.metadata().clone())
}
DataType::FixedSizeList(inner, size) => Field::new(
field.name(),
DataType::FixedSizeList(
Arc::new(prepare_field_for_flight(
inner,
dictionary_tracker,
send_dictionaries,
)),
*size,
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
DataType::RunEndEncoded(run_ends, values) => Field::new(
field.name(),
DataType::RunEndEncoded(
run_ends.clone(),
Arc::new(prepare_field_for_flight(
values,
dictionary_tracker,
send_dictionaries,
)),
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
DataType::Map(inner, sorted) => Field::new(
field.name(),
DataType::Map(
Expand All @@ -556,7 +609,37 @@ fn prepare_field_for_flight(
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
_ => field.as_ref().clone(),
DataType::Null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is good -- it gives confidence we aren't missing another nested type 👍

Thank you

| DataType::Boolean
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Timestamp(_, _)
| DataType::Date32
| DataType::Date64
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Duration(_)
| DataType::Interval(_)
| DataType::Binary
| DataType::FixedSizeBinary(_)
| DataType::LargeBinary
| DataType::BinaryView
| DataType::Utf8
| DataType::LargeUtf8
| DataType::Utf8View
| DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _) => field.as_ref().clone(),
}
}

Expand All @@ -573,33 +656,7 @@ fn prepare_schema_for_flight(
let fields: Fields = schema
.fields()
.iter()
.map(|field| match field.data_type() {
DataType::Dictionary(_, value_type) => {
if !send_dictionaries {
Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
)
.with_metadata(field.metadata().clone())
} else {
dictionary_tracker.next_dict_id();
#[allow(deprecated)]
Field::new_dict(
field.name(),
field.data_type().clone(),
field.is_nullable(),
0,
field.dict_is_ordered().unwrap_or_default(),
)
.with_metadata(field.metadata().clone())
}
}
tpe if tpe.is_nested() => {
prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
}
_ => field.as_ref().clone(),
})
.map(|field| prepare_field_for_flight(field, dictionary_tracker, send_dictionaries))
.collect();

Schema::new(fields).with_metadata(schema.metadata().clone())
Expand Down Expand Up @@ -729,7 +786,8 @@ fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef
mod tests {
use crate::decode::{DecodedPayload, FlightDataDecoder};
use arrow_array::builder::{
GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
FixedSizeListBuilder, GenericByteDictionaryBuilder, GenericListViewBuilder, ListBuilder,
StringDictionaryBuilder, StructBuilder,
};
use arrow_array::*;
use arrow_array::{cast::downcast_array, types::*};
Expand Down Expand Up @@ -1540,6 +1598,160 @@ mod tests {
verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_ree_resend() {
    // Run-end-encoded array whose values are themselves dictionary-encoded:
    // REE(Int32 run ends, Dict(Int32, Utf8)). Two batches with different
    // dictionary contents force the encoder to resend the dictionary.
    let values_a: DictionaryArray<Int32Type> =
        [Some("a"), None, Some("b")].into_iter().collect();
    let ends_a = Int32Array::from(vec![1, 2, 3]);
    let ree_a = RunArray::try_new(&ends_a, &values_a).unwrap();

    // Second batch reuses "a" but introduces "c", so its dictionary differs.
    let values_b: DictionaryArray<Int32Type> =
        [Some("c"), Some("a")].into_iter().collect();
    let ends_b = Int32Array::from(vec![1, 2]);
    let ree_b = RunArray::try_new(&ends_b, &values_b).unwrap();

    let schema = Arc::new(Schema::new(vec![Field::new(
        "ree",
        ree_a.data_type().clone(),
        true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(ree_a)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(ree_b)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_of_struct_of_dict_resend() {
    // Dict(Int8, Struct { dict: Dict(Int32, Utf8), int: Int32 })
    // This exercises the Dictionary branch recursing into its value type
    // before assigning its own dict_id (depth-first ordering).
    let dict_field = Arc::new(Field::new_dictionary(
        "dict",
        DataType::Int32,
        DataType::Utf8,
        true,
    ));
    let int_field = Arc::new(Field::new("int", DataType::Int32, false));

    // Batch 1: nested dictionary over ["alpha", null, "beta", "gamma"],
    // keys [0, 1, 2, 3, 0], paired with a plain Int32 column.
    let strings1 = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
    let keys1 = Int32Array::from_iter_values([0, 1, 2, 3, 0]);
    let nested_dict1 = DictionaryArray::new(keys1, Arc::new(strings1));
    let ints1 = Int32Array::from(vec![10, 20, 30, 40, 50]);

    let struct1 = StructArray::from(vec![
        (dict_field.clone(), Arc::new(nested_dict1) as ArrayRef),
        (int_field.clone(), Arc::new(ints1) as ArrayRef),
    ]);
    let outer_dict1 =
        DictionaryArray::new(Int8Array::from_iter_values([0, 0, 1, 2]), Arc::new(struct1));

    // Batch 2: fresh values at both dictionary levels.
    let strings2 = StringArray::from(vec![Some("x"), Some("y")]);
    let keys2 = Int32Array::from_iter_values([0, 1, 0]);
    let nested_dict2 = DictionaryArray::new(keys2, Arc::new(strings2));
    let ints2 = Int32Array::from(vec![100, 200, 300]);

    let struct2 = StructArray::from(vec![
        (dict_field, Arc::new(nested_dict2) as ArrayRef),
        (int_field, Arc::new(ints2) as ArrayRef),
    ]);
    let outer_dict2 =
        DictionaryArray::new(Int8Array::from_iter_values([0, 1]), Arc::new(struct2));

    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict_struct",
        outer_dict1.data_type().clone(),
        false,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(outer_dict1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(outer_dict2)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

/// Round-trips two batches of (Large)ListView-of-dictionary data, where the
/// second batch carries different dictionary values so the encoder must
/// resend the dictionary. `O` selects ListView (i32) vs LargeListView (i64).
async fn verify_dictionary_list_view_resend<O: OffsetSizeTrait>() {
    let mut builder =
        GenericListViewBuilder::<O, _>::new(StringDictionaryBuilder::<UInt16Type>::new());

    builder.append_value(vec![Some("a"), None, Some("b")]);
    let first = builder.finish();

    builder.append_value(vec![Some("c"), None, Some("d")]);
    let second = builder.finish();

    let item_field = Arc::new(Field::new_dictionary(
        "item",
        DataType::UInt16,
        DataType::Utf8,
        true,
    ));
    let list_type = match O::IS_LARGE {
        true => DataType::LargeListView(item_field),
        false => DataType::ListView(item_field),
    };
    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict_list_view",
        list_type,
        true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(first)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(second)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_list_view_resend() {
    // i32 offsets -> DataType::ListView variant.
    verify_dictionary_list_view_resend::<i32>().await;
}

#[tokio::test]
async fn test_dictionary_large_list_view_resend() {
    // i64 offsets -> DataType::LargeListView variant.
    verify_dictionary_list_view_resend::<i64>().await;
}

#[tokio::test]
async fn test_dictionary_fixed_size_list_resend() {
    // FixedSizeList(Dict(UInt16, Utf8), 2): each batch holds one list of two
    // dictionary-encoded strings; the second batch has fresh dictionary
    // values, so the dictionary must be resent.
    let mut builder =
        FixedSizeListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new(), 2);

    for value in ["a", "b"] {
        builder.values().append_value(value);
    }
    builder.append(true);
    let first = builder.finish();

    for value in ["c", "d"] {
        builder.values().append_value(value);
    }
    builder.append(true);
    let second = builder.finish();

    let item_field = Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true);
    let schema = Arc::new(Schema::new(vec![Field::new_fixed_size_list(
        "dict_fsl", item_field, 2, true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(first)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(second)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
let expected_schema = batches.first().unwrap().schema();

Expand Down
7 changes: 6 additions & 1 deletion arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,12 @@ impl IpcDataGenerator {
// sequence is assigned depth-first, so we need to first encode children and have
// them take their assigned dict IDs before we take the dict ID for this field.
let dict_id = dict_id_seq.next().ok_or_else(|| {
ArrowError::IpcError(format!("no dict id for field {}", field.name()))
ArrowError::IpcError(format!(
"no dict id for field {:?}: field.data_type={:?}, column.data_type={:?}",
field.name(),
field.data_type(),
column.data_type()
))
})?;

match dictionary_tracker.insert_column(
Expand Down
Loading