Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 200 additions & 29 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,28 @@ fn prepare_field_for_flight(
}
DataType::Dictionary(_, value_type) => {
if !send_dictionaries {
Field::new(
// Recurse into value type to handle nested dicts being stripped
let value_field = Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
);
prepare_field_for_flight(
&Arc::new(value_field),
dictionary_tracker,
send_dictionaries,
)
.with_metadata(field.metadata().clone())
} else {
// Recurse into value type BEFORE registering this dict's id,
// matching the depth-first order of encode_dictionaries in the
// IPC writer which processes nested dicts before the parent.
let value_field = Field::new("values", value_type.as_ref().clone(), true);
prepare_field_for_flight(
&Arc::new(value_field),
dictionary_tracker,
send_dictionaries,
);
dictionary_tracker.next_dict_id();
#[allow(deprecated)]
Field::new_dict(
Expand All @@ -547,6 +562,44 @@ fn prepare_field_for_flight(
.with_metadata(field.metadata().clone())
}
}
DataType::ListView(inner) | DataType::LargeListView(inner) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 this makes sense to add handling for these nested types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably be good to change the catch all at the bottom of the match statement to explicitly list out all remaining DataTypes as a way to audit that we didn't miss any additional nested types (either now or in the future)

        .with_metadata(field.metadata().clone()),
        _ => field.as_ref().clone(),
    }

let prepared = prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries);
Field::new(
field.name(),
match field.data_type() {
DataType::ListView(_) => DataType::ListView(Arc::new(prepared)),
_ => DataType::LargeListView(Arc::new(prepared)),
},
field.is_nullable(),
)
.with_metadata(field.metadata().clone())
}
DataType::FixedSizeList(inner, size) => Field::new(
field.name(),
DataType::FixedSizeList(
Arc::new(prepare_field_for_flight(
inner,
dictionary_tracker,
send_dictionaries,
)),
*size,
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
DataType::RunEndEncoded(run_ends, values) => Field::new(
field.name(),
DataType::RunEndEncoded(
run_ends.clone(),
Arc::new(prepare_field_for_flight(
values,
dictionary_tracker,
send_dictionaries,
)),
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
DataType::Map(inner, sorted) => Field::new(
field.name(),
DataType::Map(
Expand All @@ -573,33 +626,7 @@ fn prepare_schema_for_flight(
let fields: Fields = schema
.fields()
.iter()
.map(|field| match field.data_type() {
DataType::Dictionary(_, value_type) => {
if !send_dictionaries {
Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
)
.with_metadata(field.metadata().clone())
} else {
dictionary_tracker.next_dict_id();
#[allow(deprecated)]
Field::new_dict(
field.name(),
field.data_type().clone(),
field.is_nullable(),
0,
field.dict_is_ordered().unwrap_or_default(),
)
.with_metadata(field.metadata().clone())
}
}
tpe if tpe.is_nested() => {
prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
}
_ => field.as_ref().clone(),
})
.map(|field| prepare_field_for_flight(field, dictionary_tracker, send_dictionaries))
.collect();

Schema::new(fields).with_metadata(schema.metadata().clone())
Expand Down Expand Up @@ -729,7 +756,8 @@ fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef
mod tests {
use crate::decode::{DecodedPayload, FlightDataDecoder};
use arrow_array::builder::{
GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
FixedSizeListBuilder, GenericByteDictionaryBuilder, GenericListViewBuilder, ListBuilder,
StringDictionaryBuilder, StructBuilder,
};
use arrow_array::*;
use arrow_array::{cast::downcast_array, types::*};
Expand Down Expand Up @@ -1540,6 +1568,149 @@ mod tests {
verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_ree_resend() {
    // Round-trip two batches of run-end-encoded dictionary data so the
    // dictionary nested inside the RunEndEncoded column must be re-sent
    // for the second batch.
    let first_values: DictionaryArray<Int32Type> =
        [Some("a"), None, Some("b")].into_iter().collect();
    let first_ends = Int32Array::from(vec![1, 2, 3]);
    let first = RunArray::try_new(&first_ends, &first_values).unwrap();

    let second_values: DictionaryArray<Int32Type> =
        [Some("c"), Some("a")].into_iter().collect();
    let second_ends = Int32Array::from(vec![1, 2]);
    let second = RunArray::try_new(&second_ends, &second_values).unwrap();

    let schema = Arc::new(Schema::new(vec![Field::new(
        "ree",
        first.data_type().clone(),
        true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(first)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(second)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_of_struct_of_dict_resend() {
// Dict(Int8, Struct { dict: Dict(Int32, Utf8), int: Int32 })
// This exercises the Dictionary branch recursing into its value type
// before assigning its own dict_id (depth-first ordering).
let struct_fields: Vec<Field> = vec![
Field::new_dictionary("dict", DataType::Int32, DataType::Utf8, true),
Field::new("int", DataType::Int32, false),
];

let inner_values =
StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
let inner_keys = Int32Array::from_iter_values([0, 1, 2, 3, 0]);
let inner_dict = DictionaryArray::new(inner_keys, Arc::new(inner_values));
let int_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

let struct_array = StructArray::from(vec![
(
Arc::new(struct_fields[0].clone()),
Arc::new(inner_dict) as ArrayRef,
),
(
Arc::new(struct_fields[1].clone()),
Arc::new(int_array) as ArrayRef,
),
]);

let outer_keys = Int8Array::from_iter_values([0, 0, 1, 2]);
let arr1 = DictionaryArray::new(outer_keys, Arc::new(struct_array));

let inner_values2 = StringArray::from(vec![Some("x"), Some("y")]);
let inner_keys2 = Int32Array::from_iter_values([0, 1, 0]);
let inner_dict2 = DictionaryArray::new(inner_keys2, Arc::new(inner_values2));
let int_array2 = Int32Array::from(vec![100, 200, 300]);

let struct_array2 = StructArray::from(vec![
(
Arc::new(struct_fields[0].clone()),
Arc::new(inner_dict2) as ArrayRef,
),
(
Arc::new(struct_fields[1].clone()),
Arc::new(int_array2) as ArrayRef,
),
]);

let outer_keys2 = Int8Array::from_iter_values([0, 1]);
let arr2 = DictionaryArray::new(outer_keys2, Arc::new(struct_array2));

let schema = Arc::new(Schema::new(vec![Field::new(
"dict_struct",
arr1.data_type().clone(),
false,
)]));

let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_list_view_resend() {
    // ListView (i32 offsets) of dictionary-encoded strings, sent over two
    // batches so the nested dictionary must be re-sent.
    let mut builder =
        GenericListViewBuilder::<i32, _>::new(StringDictionaryBuilder::<UInt16Type>::new());

    builder.append_value(vec![Some("a"), None, Some("b")]);
    let arr1 = builder.finish();

    builder.append_value(vec![Some("c"), None, Some("d")]);
    let arr2 = builder.finish();

    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict_list_view",
        DataType::ListView(Arc::new(Field::new_dictionary(
            "item",
            DataType::UInt16,
            DataType::Utf8,
            true,
        ))),
        true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_large_list_view_resend() {
    // Same as test_dictionary_list_view_resend but with i64 offsets
    // (LargeListView), covering the other arm of the ListView handling.
    let mut builder =
        GenericListViewBuilder::<i64, _>::new(StringDictionaryBuilder::<UInt16Type>::new());

    builder.append_value(vec![Some("a"), None, Some("b")]);
    let arr1 = builder.finish();

    builder.append_value(vec![Some("c"), None, Some("d")]);
    let arr2 = builder.finish();

    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict_large_list_view",
        DataType::LargeListView(Arc::new(Field::new_dictionary(
            "item",
            DataType::UInt16,
            DataType::Utf8,
            true,
        ))),
        true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_fixed_size_list_resend() {
    // FixedSizeList(2) of dictionary-encoded strings across two batches,
    // forcing the inner dictionary to be re-sent for the second batch.
    let mut builder =
        FixedSizeListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new(), 2);

    for value in ["a", "b"] {
        builder.values().append_value(value);
    }
    builder.append(true);
    let first = builder.finish();

    for value in ["c", "d"] {
        builder.values().append_value(value);
    }
    builder.append(true);
    let second = builder.finish();

    let item = Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true);
    let schema = Arc::new(Schema::new(vec![Field::new_fixed_size_list(
        "dict_fsl", item, 2, true,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(first)]).unwrap();
    let batch2 = RecordBatch::try_new(schema, vec![Arc::new(second)]).unwrap();

    verify_flight_round_trip(vec![batch1, batch2]).await;
}

async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
let expected_schema = batches.first().unwrap().schema();

Expand Down
7 changes: 6 additions & 1 deletion arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,12 @@ impl IpcDataGenerator {
// sequence is assigned depth-first, so we need to first encode children and have
// them take their assigned dict IDs before we take the dict ID for this field.
let dict_id = dict_id_seq.next().ok_or_else(|| {
ArrowError::IpcError(format!("no dict id for field {}", field.name()))
ArrowError::IpcError(format!(
"no dict id for field {:?}: field.data_type={:?}, column.data_type={:?}",
field.name(),
field.data_type(),
column.data_type()
))
})?;

match dictionary_tracker.insert_column(
Expand Down
Loading