Skip to content

Commit

Permalink
Stop losing precision and scale when casting decimal to dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Sep 11, 2024
1 parent f050ff7 commit 26095fc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
30 changes: 28 additions & 2 deletions arrow-cast/src/cast/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,36 @@ pub(crate) fn cast_to_dictionary<K: ArrowDictionaryKeyType>(
UInt32 => pack_numeric_to_dictionary::<K, UInt32Type>(array, dict_value_type, cast_options),
UInt64 => pack_numeric_to_dictionary::<K, UInt64Type>(array, dict_value_type, cast_options),
Decimal128(_, _) => {
pack_numeric_to_dictionary::<K, Decimal128Type>(array, dict_value_type, cast_options)
// pack_numeric_to_dictionary loses the precision and scale so we have to perform a
// second cast
let decimal_dict_max_precision_scale = pack_numeric_to_dictionary::<K, Decimal128Type>(
array,
dict_value_type,
cast_options,
)?;
let expected_type =
Dictionary(Box::new(K::DATA_TYPE), Box::new(dict_value_type.clone()));
cast_with_options(
&decimal_dict_max_precision_scale,
&expected_type,
cast_options,
)
}
Decimal256(_, _) => {
pack_numeric_to_dictionary::<K, Decimal256Type>(array, dict_value_type, cast_options)
// pack_numeric_to_dictionary loses the precision and scale so we have to perform a
// second cast
let decimal_dict_max_precision_scale = pack_numeric_to_dictionary::<K, Decimal256Type>(
array,
dict_value_type,
cast_options,
)?;
let expected_type =
Dictionary(Box::new(K::DATA_TYPE), Box::new(dict_value_type.clone()));
cast_with_options(
&decimal_dict_max_precision_scale,
&expected_type,
cast_options,
)
}
Float16 => {
pack_numeric_to_dictionary::<K, Float16Type>(array, dict_value_type, cast_options)
Expand Down
32 changes: 32 additions & 0 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2650,6 +2650,38 @@ mod tests {
err.unwrap_err().to_string());
}

#[test]
fn test_cast_decimal128_to_decimal128_dict() {
    // A decimal -> dictionary-of-decimal cast must report exactly the requested
    // dictionary type, i.e. the value type keeps the caller's precision and scale.
    let precision = 20;
    let scale = 3;
    let source_type = DataType::Decimal128(precision, scale);
    let target_type = DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Decimal128(precision, scale)),
    );
    assert!(can_cast_types(&source_type, &target_type));

    // Three distinct values plus a null.
    let raw_values = vec![Some(1123456), Some(2123456), Some(3123456), None];
    let decimals = create_decimal_array(raw_values, precision, scale).unwrap();

    let options = CastOptions::default();
    let dictionary = cast_with_options(&decimals, &target_type, &options).unwrap();
    assert_eq!(dictionary.data_type(), &target_type);
}

#[test]
fn test_cast_decimal256_to_decimal256_dict() {
    // A Decimal256 -> dictionary-of-Decimal256 cast must report exactly the requested
    // dictionary type, i.e. the value type keeps the caller's precision and scale.
    let p = 20;
    let s = 3;
    let input_type = DataType::Decimal256(p, s);
    let output_type = DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Decimal256(p, s)),
    );
    assert!(can_cast_types(&input_type, &output_type));

    let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
    let array = create_decimal_array(array, p, s).unwrap();
    // NOTE(review): `create_decimal_array` is the same helper the Decimal128 test uses, so it
    // presumably yields a Decimal128 array -- confirm. Casting the fixture to `input_type`
    // ensures the cast under test really starts from Decimal256; if the helper already
    // produced that type, this step should leave the array unchanged.
    let array = cast_with_options(&array, &input_type, &CastOptions::default()).unwrap();
    assert_eq!(array.data_type(), &input_type);

    let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap();
    assert_eq!(cast_array.data_type(), &output_type);
}

#[test]
fn test_cast_decimal128_to_decimal128_overflow() {
let input_type = DataType::Decimal128(38, 3);
Expand Down

0 comments on commit 26095fc

Please sign in to comment.