diff --git a/parquet/benches/arrow_writer.rs b/parquet/benches/arrow_writer.rs index b92f0788b2fc..b5b38b948bf1 100644 --- a/parquet/benches/arrow_writer.rs +++ b/parquet/benches/arrow_writer.rs @@ -33,6 +33,7 @@ use arrow::datatypes::*; use arrow::util::bench_util::{create_f16_array, create_f32_array, create_f64_array}; use arrow::{record_batch::RecordBatch, util::data_gen::*}; use arrow_array::RecordBatchOptions; +use arrow_array::builder::StringDictionaryBuilder; use parquet::arrow::ArrowSchemaConverter; use parquet::errors::Result; use parquet::file::properties::{WriterProperties, WriterVersion}; @@ -139,6 +140,43 @@ fn create_string_dictionary_bench_batch( )?) } +/// Creates a DictionaryArray with low cardinality (~15 unique values across +/// `size` rows). This simulates realistic categorical columns like status +/// codes, country codes, or ship modes where dictionary encoding excels. +fn create_string_dictionary_low_cardinality_bench_batch(size: usize) -> Result<RecordBatch> { + let categories = [ + "DELIVERED", + "SHIPPED", + "PENDING", + "CANCELLED", + "RETURNED", + "PROCESSING", + "ON_HOLD", + "REFUNDED", + "BACKORDERED", + "IN_TRANSIT", + "CONFIRMED", + "DISPATCHED", + "FAILED", + "COMPLETED", + "UNKNOWN", + ]; + let mut builder = StringDictionaryBuilder::<Int32Type>::new(); + for i in 0..size { + builder.append_value(categories[i % categories.len()]); + } + let dict = builder.finish(); + let schema = Schema::new(vec![Field::new( + "_1", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )]); + Ok(RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(dict)], + )?) 
+} + fn create_string_bench_batch_non_null( size: usize, null_density: f32, @@ -399,6 +437,9 @@ fn create_batches() -> Vec<(&'static str, RecordBatch)> { let batch = create_string_dictionary_bench_batch(BATCH_SIZE, 0.25, 0.75).unwrap(); batches.push(("string_dictionary", batch)); + let batch = create_string_dictionary_low_cardinality_bench_batch(BATCH_SIZE).unwrap(); + batches.push(("string_dictionary_low_cardinality", batch)); + let batch = create_string_bench_batch_non_null(BATCH_SIZE, 0.25, 0.75).unwrap(); batches.push(("string_non_null", batch)); diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs index 228d229b3088..74378e1fc7ba 100644 --- a/parquet/src/arrow/arrow_writer/byte_array.rs +++ b/parquet/src/arrow/arrow_writer/byte_array.rs @@ -32,6 +32,7 @@ use arrow_array::{ Array, ArrayAccessor, BinaryArray, BinaryViewArray, DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, }; +use arrow_buffer::ArrowNativeType; use arrow_schema::DataType; macro_rules! downcast_dict_impl { @@ -97,6 +98,65 @@ macro_rules! downcast_op { }; } +/// Dispatches to `encode_with_remap` for Dictionary types, providing the +/// `row_to_key` closure that maps row indices to dictionary key indices. +macro_rules! downcast_dict_remap { + ($key_type:expr, $val:ident, $array:ident, $op:expr, $indices:expr, $encoder:expr) => {{ + macro_rules! 
inner { + ($kt:ident) => {{ + let dict_array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let typed = dict_array.downcast_dict::<$val>().unwrap(); + let keys = dict_array.keys(); + let dict_len = dict_array.values().len(); + let row_to_key = |idx: usize| -> usize { keys.value(idx).as_usize() }; + $op(typed, $indices, $encoder, dict_len, &row_to_key) + }}; + } + match $key_type.as_ref() { + DataType::UInt8 => inner!(UInt8Type), + DataType::UInt16 => inner!(UInt16Type), + DataType::UInt32 => inner!(UInt32Type), + DataType::UInt64 => inner!(UInt64Type), + DataType::Int8 => inner!(Int8Type), + DataType::Int16 => inner!(Int16Type), + DataType::Int32 => inner!(Int32Type), + DataType::Int64 => inner!(Int64Type), + _ => unreachable!(), + } + }}; +} + +/// Macro that dispatches to `encode_with_remap` for Dictionary data types. +/// For non-Dictionary types, this macro should never be called. +macro_rules! downcast_op_remap { + ($data_type:expr, $array:ident, $op:expr, $indices:expr, $encoder:expr) => { + match $data_type { + DataType::Dictionary(key, value) => match value.as_ref() { + DataType::Utf8 => { + downcast_dict_remap!(key, StringArray, $array, $op, $indices, $encoder) + } + DataType::LargeUtf8 => { + downcast_dict_remap!(key, LargeStringArray, $array, $op, $indices, $encoder) + } + DataType::Binary => { + downcast_dict_remap!(key, BinaryArray, $array, $op, $indices, $encoder) + } + DataType::LargeBinary => { + downcast_dict_remap!(key, LargeBinaryArray, $array, $op, $indices, $encoder) + } + DataType::FixedSizeBinary(_) => { + downcast_dict_remap!(key, FixedSizeBinaryArray, $array, $op, $indices, $encoder) + } + d => unreachable!("cannot downcast {} dictionary value to byte array", d), + }, + d => unreachable!("downcast_op_remap called with non-dictionary type {}", d), + } + }; +} + /// A fallback encoder, i.e. 
non-dictionary, for [`ByteArray`] struct FallbackEncoder { encoder: FallbackEncoderImpl, @@ -356,6 +416,50 @@ impl DictEncoder { } } + /// Fast path for DictionaryArray input with a lazy remap table. + /// + /// Instead of interning each row's value individually (O(N) hash operations), + /// this method builds a lazy remap table of size O(D) where D is the number + /// of unique dictionary values actually referenced, then maps each row's key + /// through the remap table using a simple array index lookup. + /// + /// The `row_to_key` closure extracts the dictionary key (as usize) for a given + /// row index. This avoids allocating a separate `Vec` for the keys. + /// + /// The remap table uses `Vec<Option<u64>>` with lazy population: values are + /// interned on first encounter and cached for subsequent rows. This ensures + /// only referenced dictionary values are interned, producing byte-identical + /// output to the per-row path. + fn encode_with_remap<T, F>( + &mut self, + values: T, + indices: &[usize], + dict_len: usize, + row_to_key: F, + ) where + T: ArrayAccessor + Copy, + T::Item: AsRef<[u8]>, + F: Fn(usize) -> usize, + { + let mut remap: Vec<Option<u64>> = vec![None; dict_len]; + + self.indices.extend(indices.iter().map(|&idx| { + let key = row_to_key(idx); + let interned = match remap[key] { + Some(cached) => cached, + None => { + let value = values.value(idx); + let fresh = self.interner.intern(value.as_ref()); + remap[key] = Some(fresh); + fresh + } + }; + let value = values.value(idx); + self.variable_length_bytes += value.as_ref().len() as i64; + interned + })); + } + fn bit_width(&self) -> u8 { let length = self.interner.storage().values.len(); num_required_bits(length.saturating_sub(1) as u64) @@ -468,6 +572,28 @@ impl ColumnValueEncoder for ByteArrayEncoder { } fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> Result<()> { + // Fast path: when input is a DictionaryArray and dictionary encoding is + // enabled, use a remap-based approach that replaces 
O(N) hash operations + // with O(D) hash operations (D = unique dictionary values) plus O(N) + // simple array index lookups. Only used when D <= N/2 (low cardinality), + // as the remap table overhead is not worthwhile for high-cardinality + // dictionaries. + if let DataType::Dictionary(key_type, _value_type) = values.data_type() { + if self.dict_encoder.is_some() && self.geo_stats_accumulator.is_none() { + let dict_len = get_dict_len(values, key_type); + if dict_len <= indices.len() / 2 { + downcast_op_remap!( + values.data_type(), + values, + encode_with_remap, + indices, + self + ); + return Ok(()); + } + } + } + downcast_op!(values.data_type(), values, encode, indices, self); Ok(()) } @@ -584,6 +710,84 @@ where } } +/// Get the dictionary length from a DictionaryArray, dispatching on key type. +fn get_dict_len(values: &dyn Array, key_type: &DataType) -> usize { + macro_rules! get_len { + ($kt:ident) => { + values + .as_any() + .downcast_ref::<DictionaryArray<$kt>>() + .unwrap() + .values() + .len() + }; + } + match key_type { + DataType::Int8 => get_len!(Int8Type), + DataType::Int16 => get_len!(Int16Type), + DataType::Int32 => get_len!(Int32Type), + DataType::Int64 => get_len!(Int64Type), + DataType::UInt8 => get_len!(UInt8Type), + DataType::UInt16 => get_len!(UInt16Type), + DataType::UInt32 => get_len!(UInt32Type), + DataType::UInt64 => get_len!(UInt64Type), + _ => unreachable!(), + } +} + +/// Encodes dictionary array values using a remap-based fast path. +/// +/// This is equivalent to [`encode`] but optimizes the dictionary encoding step: +/// instead of O(N) hash operations, it uses O(D) hash operations where D is +/// the number of unique dictionary values, plus O(N) simple array lookups. +/// +/// Called via `downcast_op_remap!` which dispatches to the appropriate +/// `TypedDictionaryArray` for Dictionary types. The `TypedDictionaryArray` +/// implements `ArrayAccessor`, which transparently resolves dictionary keys +/// to values. 
+/// +/// The `row_to_key` closure extracts the dictionary key (as usize) for a given +/// row index. This is provided by the `downcast_dict_remap!` macro. +fn encode_with_remap<T>( + values: T, + indices: &[usize], + encoder: &mut ByteArrayEncoder, + dict_len: usize, + row_to_key: &dyn Fn(usize) -> usize, +) where + T: ArrayAccessor + Copy, + T::Item: Copy + Ord + AsRef<[u8]>, +{ + // Statistics: use existing per-row computation for correctness + if encoder.statistics_enabled != EnabledStatistics::None { + // geo_stats_accumulator is guaranteed None (checked in write_gather) + if let Some((min, max)) = compute_min_max(values, indices.iter().cloned()) { + if encoder.min_value.as_ref().is_none_or(|m| m > &min) { + encoder.min_value = Some(min); + } + if encoder.max_value.as_ref().is_none_or(|m| m < &max) { + encoder.max_value = Some(max); + } + } + } + + // Bloom filter: O(D) insertion using seen-tracking per dictionary key + if let Some(bloom_filter) = &mut encoder.bloom_filter { + let mut seen = vec![false; dict_len]; + for &idx in indices { + let key = row_to_key(idx); + if !seen[key] { + seen[key] = true; + bloom_filter.insert(values.value(idx).as_ref()); + } + } + } + + // Dictionary encoding: remap-based fast path + let dict_encoder = encoder.dict_encoder.as_mut().unwrap(); + dict_encoder.encode_with_remap(values, indices, dict_len, row_to_key); +} + /// Computes the min and max for the provided array and indices /// /// This is a free function so it can be used with `downcast_op!` @@ -624,3 +828,358 @@ fn update_geo_stats_accumulator( } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::builder::StringDictionaryBuilder; + use arrow_array::cast::AsArray; + use arrow_array::types::Int32Type; + use arrow_array::{Array, ArrayAccessor, DictionaryArray, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use bytes::Bytes; + + use crate::arrow::ArrowWriter; + use 
crate::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + use crate::file::properties::WriterProperties; + + /// Write a single RecordBatch to Parquet bytes using the given properties. + fn write_batch_to_bytes(batch: &RecordBatch, props: Option<WriterProperties>) -> Bytes { + let mut buf = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), props).unwrap(); + writer.write(batch).unwrap(); + writer.close().unwrap(); + buf.into() + } + + /// Read all rows from Parquet bytes as RecordBatches. + fn read_batches_from_bytes(data: &Bytes) -> Vec<RecordBatch> { + let reader = ParquetRecordBatchReaderBuilder::try_new(data.clone()) + .unwrap() + .build() + .unwrap(); + reader.collect::<Result<Vec<_>, _>>().unwrap() + } + + /// Extract string values from a column, handling both StringArray and + /// DictionaryArray transparently. + fn column_to_strings(col: &dyn Array) -> Vec<Option<String>> { + match col.data_type() { + DataType::Utf8 => { + let sa = col.as_string::<i32>(); + (0..sa.len()) + .map(|i| { + if sa.is_null(i) { + None + } else { + Some(sa.value(i).to_string()) + } + }) + .collect() + } + DataType::Dictionary(_, _) => { + let da = col + .as_any() + .downcast_ref::<DictionaryArray<Int32Type>>() + .unwrap(); + let typed = da.downcast_dict::<StringArray>().unwrap(); + (0..col.len()) + .map(|i| { + if col.is_null(i) { + None + } else { + Some(typed.value(i).to_string()) + } + }) + .collect() + } + other => panic!("Unexpected data type: {other}"), + } + } + + // T1: Data equivalence (DictionaryArray vs StringArray) + // + // The Parquet files differ in Arrow schema metadata (Utf8 vs Dictionary), + // but the data pages, dictionary pages, and column statistics must match. 
+ #[test] + fn test_dict_passthrough_data_equivalence() { + use crate::file::reader::FileReader; + use crate::file::serialized_reader::SerializedFileReader; + + let strings = vec!["alpha", "beta", "alpha", "gamma", "beta"]; + + // Plain StringArray + let plain = StringArray::from(strings.clone()); + let plain_schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let plain_batch = RecordBatch::try_new(plain_schema, vec![Arc::new(plain)]).unwrap(); + + // DictionaryArray with the same data + let dict: DictionaryArray = strings.into_iter().collect(); + let dict_schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let dict_batch = RecordBatch::try_new(dict_schema, vec![Arc::new(dict)]).unwrap(); + + let plain_bytes = write_batch_to_bytes(&plain_batch, None); + let dict_bytes = write_batch_to_bytes(&dict_batch, None); + + // Compare column chunk metadata + let plain_reader = SerializedFileReader::new(plain_bytes.clone()).unwrap(); + let dict_reader = SerializedFileReader::new(dict_bytes.clone()).unwrap(); + + let plain_meta = plain_reader.metadata().row_group(0).column(0); + let dict_meta = dict_reader.metadata().row_group(0).column(0); + + assert_eq!(plain_meta.statistics(), dict_meta.statistics()); + assert_eq!(plain_meta.num_values(), dict_meta.num_values()); + assert_eq!(plain_meta.compressed_size(), dict_meta.compressed_size()); + assert_eq!( + plain_meta.uncompressed_size(), + dict_meta.uncompressed_size() + ); + + // Verify both read back the same logical values + let pb = read_batches_from_bytes(&plain_bytes); + let db = read_batches_from_bytes(&dict_bytes); + let plain_vals = column_to_strings(pb[0].column(0).as_ref()); + let dict_vals = column_to_strings(db[0].column(0).as_ref()); + assert_eq!(plain_vals, dict_vals); + } + + // T2: Roundtrip DictionaryArray -> read back -> verify values + #[test] + fn test_dict_passthrough_roundtrip() 
{ + let strings = ["hello", "world", "hello", "foo", "world", "bar"]; + let dict: DictionaryArray = strings.iter().copied().collect(); + let schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict)]).unwrap(); + + let bytes = write_batch_to_bytes(&batch, None); + let batches = read_batches_from_bytes(&bytes); + + assert_eq!(batches.len(), 1); + let vals = column_to_strings(batches[0].column(0).as_ref()); + let expected: Vec> = strings.iter().map(|s| Some(s.to_string())).collect(); + assert_eq!(vals, expected); + } + + // T3: Roundtrip DictionaryArray -> verify values match + #[test] + fn test_dict_passthrough_roundtrip_to_plain() { + let strings = ["cat", "dog", "cat", "bird"]; + let dict: DictionaryArray = strings.iter().copied().collect(); + let schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict)]).unwrap(); + + let bytes = write_batch_to_bytes(&batch, None); + let batches = read_batches_from_bytes(&bytes); + + let vals = column_to_strings(batches[0].column(0).as_ref()); + assert_eq!( + vals, + vec![ + Some("cat".into()), + Some("dog".into()), + Some("cat".into()), + Some("bird".into()), + ] + ); + } + + // T4: DictionaryArray with null keys + #[test] + fn test_dict_passthrough_null_keys() { + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("alpha"); + builder.append_null(); + builder.append_value("beta"); + builder.append_null(); + builder.append_value("alpha"); + let dict = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict)]).unwrap(); + + 
let bytes = write_batch_to_bytes(&batch, None); + let batches = read_batches_from_bytes(&bytes); + + let vals = column_to_strings(batches[0].column(0).as_ref()); + assert_eq!( + vals, + vec![ + Some("alpha".into()), + None, + Some("beta".into()), + None, + Some("alpha".into()), + ] + ); + } + + // T5: Mixed batches (DictionaryArray then StringArray for same column writer) + #[test] + fn test_dict_passthrough_mixed_batches() { + // First batch: DictionaryArray + let dict: DictionaryArray = vec!["aaa", "bbb", "aaa"].into_iter().collect(); + let dict_schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let dict_batch = RecordBatch::try_new(dict_schema, vec![Arc::new(dict)]).unwrap(); + + // Second batch: plain StringArray (same logical column) + let plain = StringArray::from(vec!["ccc", "bbb"]); + let plain_schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let plain_batch = RecordBatch::try_new(plain_schema, vec![Arc::new(plain)]).unwrap(); + + // Write both batches to same writer using the dict schema + let mut buf = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buf, dict_batch.schema(), None).unwrap(); + writer.write(&dict_batch).unwrap(); + writer.write(&plain_batch).unwrap(); + writer.close().unwrap(); + + let bytes: Bytes = buf.into(); + let batches = read_batches_from_bytes(&bytes); + + let mut all_values = Vec::new(); + for b in &batches { + all_values.extend(column_to_strings(b.column(0).as_ref())); + } + assert_eq!( + all_values, + vec![ + Some("aaa".into()), + Some("bbb".into()), + Some("aaa".into()), + Some("ccc".into()), + Some("bbb".into()), + ] + ); + } + + // T6: Multiple row groups with DictionaryArray input + #[test] + fn test_dict_passthrough_multiple_row_groups() { + let strings1: DictionaryArray = vec!["x", "y", "z", "x"].into_iter().collect(); + let strings2: DictionaryArray = vec!["a", "b", "a", 
"c"].into_iter().collect(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(strings1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(strings2)]).unwrap(); + + // Force each batch into its own row group + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(4)) + .build(); + + let mut buf = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buf, schema, Some(props)).unwrap(); + writer.write(&batch1).unwrap(); + writer.write(&batch2).unwrap(); + writer.close().unwrap(); + + let bytes: Bytes = buf.into(); + let batches = read_batches_from_bytes(&bytes); + + let mut all_values: Vec> = Vec::new(); + for b in &batches { + all_values.extend(column_to_strings(b.column(0).as_ref())); + } + let expected: Vec> = vec!["x", "y", "z", "x", "a", "b", "a", "c"] + .into_iter() + .map(|s| Some(s.to_string())) + .collect(); + assert_eq!(all_values, expected); + } + + // T7: Statistics correctness — same data as Dict and Plain should produce same stats + #[test] + fn test_dict_passthrough_statistics_correctness() { + use crate::file::reader::FileReader; + use crate::file::serialized_reader::SerializedFileReader; + + let strings = vec!["cherry", "apple", "banana", "apple", "cherry"]; + + // Plain + let plain = StringArray::from(strings.clone()); + let plain_schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let plain_batch = RecordBatch::try_new(plain_schema, vec![Arc::new(plain)]).unwrap(); + + // Dict + let dict: DictionaryArray = strings.into_iter().collect(); + let dict_schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let dict_batch = RecordBatch::try_new(dict_schema, vec![Arc::new(dict)]).unwrap(); + + let 
plain_bytes = write_batch_to_bytes(&plain_batch, None); + let dict_bytes = write_batch_to_bytes(&dict_batch, None); + + // Compare metadata statistics + let plain_reader = SerializedFileReader::new(plain_bytes).unwrap(); + let dict_reader = SerializedFileReader::new(dict_bytes).unwrap(); + + let plain_meta = plain_reader.metadata().row_group(0).column(0); + let dict_meta = dict_reader.metadata().row_group(0).column(0); + + assert_eq!( + plain_meta.statistics(), + dict_meta.statistics(), + "Statistics must match between plain and dictionary paths" + ); + } + + // T8: High cardinality dictionary that may trigger fallback + #[test] + fn test_dict_passthrough_high_cardinality() { + // Create a dictionary with many unique values + let values: Vec = (0..5000).map(|i| format!("value_{i:06}")).collect(); + let dict: DictionaryArray = values.iter().map(|s| s.as_str()).collect(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict)]).unwrap(); + + // Use small dictionary page size to potentially trigger fallback + let props = WriterProperties::builder() + .set_dictionary_page_size_limit(1024) + .build(); + let bytes = write_batch_to_bytes(&batch, Some(props)); + + // Verify roundtrip correctness regardless of fallback or row group splitting + let batches = read_batches_from_bytes(&bytes); + let mut all_values: Vec> = Vec::new(); + for b in &batches { + all_values.extend(column_to_strings(b.column(0).as_ref())); + } + let expected: Vec> = values.iter().map(|s| Some(s.clone())).collect(); + assert_eq!(all_values, expected); + } +} diff --git a/parquet/tests/arrow_writer_dictionary.rs b/parquet/tests/arrow_writer_dictionary.rs new file mode 100644 index 000000000000..f2be7dd167f7 --- /dev/null +++ b/parquet/tests/arrow_writer_dictionary.rs @@ -0,0 +1,332 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for writing DictionaryArray columns through `ArrowWriter` +//! and reading them back via `ParquetRecordBatchReader`. These tests exercise +//! the full public API write-read roundtrip. + +#![cfg(feature = "arrow")] + +use arrow_array::builder::StringDictionaryBuilder; +use arrow_array::cast::AsArray; +use arrow_array::types::Int32Type; +use arrow_array::{Array, ArrayAccessor, DictionaryArray, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use bytes::Bytes; +use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::file::properties::WriterProperties; +use std::sync::Arc; + +/// Helper: write a single RecordBatch to an in-memory Parquet buffer. +fn write_to_parquet(batch: &RecordBatch, props: Option) -> Bytes { + let mut buf = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), props).unwrap(); + writer.write(batch).unwrap(); + writer.close().unwrap(); + Bytes::from(buf) +} + +/// Helper: read all RecordBatches from an in-memory Parquet buffer. 
+fn read_from_parquet(data: Bytes) -> Vec { + let reader = ParquetRecordBatchReaderBuilder::try_new(data) + .unwrap() + .build() + .unwrap(); + reader.collect::, _>>().unwrap() +} + +/// Helper: extract string values from a column, handling both StringArray and +/// DictionaryArray transparently. +fn extract_strings(col: &dyn Array) -> Vec> { + if let Some(sa) = col.as_any().downcast_ref::() { + (0..sa.len()) + .map(|i| { + if sa.is_null(i) { + None + } else { + Some(sa.value(i).to_string()) + } + }) + .collect() + } else { + let da = col.as_dictionary::(); + let typed = da.downcast_dict::().unwrap(); + (0..da.len()) + .map(|i| { + if da.is_null(i) { + None + } else { + // TypedDictionaryArray::value() already resolves keys + Some(typed.value(i).to_string()) + } + }) + .collect() + } +} + +// --------------------------------------------------------------------------- +// Test 1: Basic DictionaryArray roundtrip +// --------------------------------------------------------------------------- +#[test] +fn dictionary_roundtrip_low_cardinality() { + // 15 unique values across 4096 rows — triggers the remap optimization path + let categories = [ + "DELIVERED", + "SHIPPED", + "PENDING", + "CANCELLED", + "RETURNED", + "PROCESSING", + "ON_HOLD", + "REFUNDED", + "BACKORDERED", + "IN_TRANSIT", + "CONFIRMED", + "DISPATCHED", + "FAILED", + "COMPLETED", + "UNKNOWN", + ]; + let mut builder = StringDictionaryBuilder::::new(); + let mut expected = Vec::with_capacity(4096); + for i in 0..4096 { + let val = categories[i % categories.len()]; + builder.append_value(val); + expected.push(Some(val.to_string())); + } + let dict = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "status", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict)]).unwrap(); + + let data = write_to_parquet(&batch, None); + let batches = read_from_parquet(data); + + let mut actual = 
Vec::new(); + for b in &batches { + actual.extend(extract_strings(b.column(0))); + } + assert_eq!(actual, expected); +} + +// --------------------------------------------------------------------------- +// Test 2: DictionaryArray + plain StringArray mix in the same RecordBatch +// --------------------------------------------------------------------------- +#[test] +fn dictionary_and_plain_columns_roundtrip() { + // Column 1: DictionaryArray (low cardinality) + let dict: DictionaryArray = vec!["US", "CA", "MX", "US", "CA"].into_iter().collect(); + + // Column 2: plain StringArray + let plain = StringArray::from(vec!["Alice", "Bob", "Carol", "Dave", "Eve"]); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "country", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict), Arc::new(plain)]).unwrap(); + + let data = write_to_parquet(&batch, None); + let batches = read_from_parquet(data); + + // Verify both columns + let mut countries = Vec::new(); + let mut names = Vec::new(); + for b in &batches { + countries.extend(extract_strings(b.column(0))); + names.extend(extract_strings(b.column(1))); + } + assert_eq!( + countries, + vec![ + Some("US".into()), + Some("CA".into()), + Some("MX".into()), + Some("US".into()), + Some("CA".into()), + ] + ); + assert_eq!( + names, + vec![ + Some("Alice".into()), + Some("Bob".into()), + Some("Carol".into()), + Some("Dave".into()), + Some("Eve".into()), + ] + ); +} + +// --------------------------------------------------------------------------- +// Test 3: DictionaryArray with correct statistics (min/max) +// --------------------------------------------------------------------------- +#[test] +fn dictionary_statistics_match_plain() { + use parquet::file::reader::FileReader; + use parquet::file::serialized_reader::SerializedFileReader; + + let values = vec!["cherry", "apple", 
"banana", "apple", "cherry"]; + + // Write as plain StringArray + let plain = StringArray::from(values.clone()); + let plain_schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let plain_batch = RecordBatch::try_new(plain_schema, vec![Arc::new(plain)]).unwrap(); + let plain_data = write_to_parquet(&plain_batch, None); + + // Write as DictionaryArray + let dict: DictionaryArray = values.into_iter().collect(); + let dict_schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let dict_batch = RecordBatch::try_new(dict_schema, vec![Arc::new(dict)]).unwrap(); + let dict_data = write_to_parquet(&dict_batch, None); + + // Compare statistics from both files + let plain_reader = SerializedFileReader::new(plain_data).unwrap(); + let dict_reader = SerializedFileReader::new(dict_data).unwrap(); + + let plain_stats = plain_reader + .metadata() + .row_group(0) + .column(0) + .statistics() + .unwrap() + .clone(); + let dict_stats = dict_reader + .metadata() + .row_group(0) + .column(0) + .statistics() + .unwrap() + .clone(); + + assert_eq!( + format!("{plain_stats:?}"), + format!("{dict_stats:?}"), + "Statistics from DictionaryArray path must match plain StringArray path" + ); +} + +// --------------------------------------------------------------------------- +// Test 4: Multi-row-group with DictionaryArray +// --------------------------------------------------------------------------- +#[test] +fn dictionary_multi_row_group_roundtrip() { + let batch1_values: DictionaryArray = vec!["alpha", "beta", "gamma", "alpha"] + .into_iter() + .collect(); + let batch2_values: DictionaryArray = vec!["delta", "epsilon", "delta", "gamma"] + .into_iter() + .collect(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + + let batch1 = 
RecordBatch::try_new(schema.clone(), vec![Arc::new(batch1_values)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(batch2_values)]).unwrap(); + + // Force each batch into its own row group + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(4)) + .build(); + + let mut buf = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buf, schema, Some(props)).unwrap(); + writer.write(&batch1).unwrap(); + writer.write(&batch2).unwrap(); + writer.close().unwrap(); + + let data = Bytes::from(buf); + + // Verify row group count + let reader_builder = ParquetRecordBatchReaderBuilder::try_new(data.clone()).unwrap(); + assert_eq!( + reader_builder.metadata().num_row_groups(), + 2, + "Expected 2 row groups" + ); + + // Verify all data + let batches = read_from_parquet(data); + let mut all_values = Vec::new(); + for b in &batches { + all_values.extend(extract_strings(b.column(0))); + } + let expected: Vec> = vec![ + "alpha", "beta", "gamma", "alpha", "delta", "epsilon", "delta", "gamma", + ] + .into_iter() + .map(|s| Some(s.to_string())) + .collect(); + assert_eq!(all_values, expected); +} + +// --------------------------------------------------------------------------- +// Test 5: DictionaryArray with null values roundtrip +// --------------------------------------------------------------------------- +#[test] +fn dictionary_with_nulls_roundtrip() { + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("red"); + builder.append_null(); + builder.append_value("blue"); + builder.append_null(); + builder.append_value("red"); + builder.append_value("green"); + let dict = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "color", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict)]).unwrap(); + + let data = write_to_parquet(&batch, None); + let batches = read_from_parquet(data); 
+ + let mut actual = Vec::new(); + for b in &batches { + actual.extend(extract_strings(b.column(0))); + } + assert_eq!( + actual, + vec![ + Some("red".into()), + None, + Some("blue".into()), + None, + Some("red".into()), + Some("green".into()), + ] + ); +}