diff --git a/Cargo.lock b/Cargo.lock index eee14643..00d503cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2630,6 +2630,7 @@ dependencies = [ "arrow 54.2.1", "chrono", "eyre", + "half", ] [[package]] @@ -4389,9 +4390,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" dependencies = [ "bytemuck", "cfg-if 1.0.0", diff --git a/libraries/arrow-convert/Cargo.toml b/libraries/arrow-convert/Cargo.toml index db43a61f..066c0bcc 100644 --- a/libraries/arrow-convert/Cargo.toml +++ b/libraries/arrow-convert/Cargo.toml @@ -12,4 +12,5 @@ repository.workspace = true [dependencies] arrow = { workspace = true } eyre = "0.6.8" +half = "2.5.0" chrono = "0.4.39" diff --git a/libraries/arrow-convert/src/from_impls.rs b/libraries/arrow-convert/src/from_impls.rs index f6ccb7e2..1f00ffc1 100644 --- a/libraries/arrow-convert/src/from_impls.rs +++ b/libraries/arrow-convert/src/from_impls.rs @@ -3,8 +3,8 @@ use arrow::{ datatypes::{ArrowPrimitiveType, ArrowTemporalType}, }; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; - use eyre::ContextCompat; +use half::f16; use crate::ArrowData; @@ -91,6 +91,7 @@ impl_try_from_arrow_data!( i16 => Int16Type, i32 => Int32Type, i64 => Int64Type, + f16 => Float16Type, f32 => Float32Type, f64 => Float64Type ); diff --git a/libraries/arrow-convert/src/into_impls.rs b/libraries/arrow-convert/src/into_impls.rs index f2a05fae..a8434694 100644 --- a/libraries/arrow-convert/src/into_impls.rs +++ b/libraries/arrow-convert/src/into_impls.rs @@ -1,10 +1,11 @@ use crate::IntoArrow; use arrow::array::{PrimitiveArray, StringArray, TimestampNanosecondArray}; use arrow::datatypes::{ - ArrowPrimitiveType, ArrowTimestampType, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowPrimitiveType, ArrowTimestampType, Float16Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use half::f16; impl IntoArrow for bool { type A = arrow::array::BooleanArray; @@ -44,6 +45,7 @@ impl_into_arrow!( i16 => Int16Type, i32 => Int32Type, i64 => Int64Type, + f16 => Float16Type, f32 => Float32Type, f64 => Float64Type ); diff --git a/libraries/arrow-convert/tests/conversion_test.rs b/libraries/arrow-convert/tests/conversion_test.rs index 3c5f3dbb..cf288232 100644 --- a/libraries/arrow-convert/tests/conversion_test.rs +++ b/libraries/arrow-convert/tests/conversion_test.rs @@ -1,5 +1,6 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use dora_arrow_convert::{ArrowData, IntoArrow}; +use half::f16; use std::sync::Arc; #[cfg(test)] @@ -97,6 +98,16 @@ mod tests { Ok(()) } + #[test] + fn test_f16_round_trip() -> Result<(), Report> { + let value_f16: f16 = f16::from_f32(42.42); + let arrow_array = value_f16.into_arrow(); + let data: ArrowData = ArrowData(Arc::new(arrow_array)); + let result_f16: f16 = TryFrom::try_from(&data)?; + assert_eq!(value_f16, result_f16); + Ok(()) + } + #[test] fn test_f32_round_trip() -> Result<(), Report> { let value_f32: f32 = 42.42;