@@ -4,9 +4,9 @@ use arrow::{
44 array:: { Array , ArrayRef , StructArray } ,
55 compute:: concat,
66} ;
7- use arrow_schema:: { DataType , Field , Fields } ;
7+ use arrow_schema:: { ArrowError , DataType , Field , FieldRef , Fields } ;
88use datafusion:: {
9- common:: { exec_datafusion_err, exec_err} ,
9+ common:: { arrow_datafusion_err , exec_datafusion_err, exec_err} ,
1010 error:: { DataFusionError , Result } ,
1111 logical_expr:: {
1212 ColumnarValue , ReturnFieldArgs , ScalarUDFImpl , Signature , TypeSignature , Volatility ,
@@ -20,6 +20,44 @@ use crate::shared::{
2020 try_field_as_variant_array, try_parse_string_columnar, try_parse_string_scalar,
2121} ;
2222
23+ fn type_hint_from_scalar ( field_name : & str , scalar : & ScalarValue ) -> Result < FieldRef > {
24+ let type_name = match scalar {
25+ ScalarValue :: Utf8 ( Some ( value) )
26+ | ScalarValue :: Utf8View ( Some ( value) )
27+ | ScalarValue :: LargeUtf8 ( Some ( value) ) => value. as_str ( ) ,
28+ other => {
29+ return exec_err ! (
30+ "type hint must be a non-null UTF8 literal, got {}" ,
31+ other. data_type( )
32+ ) ;
33+ }
34+ } ;
35+
36+ let casted_type = match type_name. parse :: < DataType > ( ) {
37+ Ok ( data_type) => Ok ( data_type) ,
38+ Err ( ArrowError :: ParseError ( e) ) => Err ( exec_datafusion_err ! ( "{e}" ) ) ,
39+ Err ( e) => Err ( arrow_datafusion_err ! ( e) ) ,
40+ } ?;
41+
42+ Ok ( Arc :: new ( Field :: new ( field_name, casted_type, true ) ) )
43+ }
44+
45+ fn type_hint_from_value ( field_name : & str , arg : & ColumnarValue ) -> Result < FieldRef > {
46+ match arg {
47+ ColumnarValue :: Scalar ( value) => type_hint_from_scalar ( field_name, value) ,
48+ ColumnarValue :: Array ( _) => {
49+ exec_err ! ( "type hint argument must be a scalar UTF8 literal" )
50+ }
51+ }
52+ }
53+
54+ fn build_get_options < ' a > ( path : VariantPath < ' a > , as_type : & Option < FieldRef > ) -> GetOptions < ' a > {
55+ match as_type {
56+ Some ( field) => GetOptions :: new_with_path ( path) . with_as_type ( Some ( field. clone ( ) ) ) ,
57+ None => GetOptions :: new_with_path ( path) ,
58+ }
59+ }
60+
2361#[ derive( Debug , Hash , PartialEq , Eq ) ]
2462pub struct VariantGetUdf {
2563 signature : Signature ,
@@ -28,7 +66,10 @@ pub struct VariantGetUdf {
2866impl Default for VariantGetUdf {
2967 fn default ( ) -> Self {
3068 Self {
31- signature : Signature :: new ( TypeSignature :: Any ( 2 ) , Volatility :: Immutable ) ,
69+ signature : Signature :: new (
70+ TypeSignature :: OneOf ( vec ! [ TypeSignature :: Any ( 2 ) , TypeSignature :: Any ( 3 ) ] ) ,
71+ Volatility :: Immutable ,
72+ ) ,
3273 }
3374 }
3475}
@@ -52,7 +93,14 @@ impl ScalarUDFImpl for VariantGetUdf {
5293 ) )
5394 }
5495
55- fn return_field_from_args ( & self , _args : ReturnFieldArgs ) -> Result < Arc < Field > > {
96+ fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < Arc < Field > > {
97+ if let Some ( maybe_scalar) = args. scalar_arguments . get ( 2 ) {
98+ let scalar = maybe_scalar. ok_or_else ( || {
99+ exec_datafusion_err ! ( "type hint argument to variant_get must be a literal" )
100+ } ) ?;
101+ return type_hint_from_scalar ( self . name ( ) , scalar) ;
102+ }
103+
56104 let data_type = DataType :: Struct ( Fields :: from ( vec ! [
57105 Field :: new( "metadata" , DataType :: BinaryView , false ) ,
58106 Field :: new( "value" , DataType :: BinaryView , true ) ,
@@ -67,8 +115,10 @@ impl ScalarUDFImpl for VariantGetUdf {
67115 & self ,
68116 args : datafusion:: logical_expr:: ScalarFunctionArgs ,
69117 ) -> Result < ColumnarValue > {
70- let [ variant_arg, variant_path] = args. args . as_slice ( ) else {
71- return exec_err ! ( "expected 2 arguments" ) ;
118+ let ( variant_arg, variant_path, type_arg) = match args. args . as_slice ( ) {
119+ [ variant_arg, variant_path] => ( variant_arg, variant_path, None ) ,
120+ [ variant_arg, variant_path, type_arg] => ( variant_arg, variant_path, Some ( type_arg) ) ,
121+ _ => return exec_err ! ( "expected 2 or 3 arguments" ) ,
72122 } ;
73123
74124 let variant_field = args
@@ -78,6 +128,10 @@ impl ScalarUDFImpl for VariantGetUdf {
78128
79129 try_field_as_variant_array ( variant_field. as_ref ( ) ) ?;
80130
131+ let type_field = type_arg
132+ . map ( |arg| type_hint_from_value ( self . name ( ) , arg) )
133+ . transpose ( ) ?;
134+
81135 let out = match ( variant_arg, variant_path) {
82136 ( ColumnarValue :: Array ( variant_array) , ColumnarValue :: Scalar ( variant_path) ) => {
83137 let variant_path = try_parse_string_scalar ( variant_path) ?
@@ -86,7 +140,7 @@ impl ScalarUDFImpl for VariantGetUdf {
86140
87141 let res = variant_get (
88142 variant_array,
89- GetOptions :: new_with_path ( VariantPath :: from ( variant_path) ) ,
143+ build_get_options ( VariantPath :: from ( variant_path) , & type_field ) ,
90144 ) ?;
91145
92146 ColumnarValue :: Array ( res)
@@ -104,14 +158,11 @@ impl ScalarUDFImpl for VariantGetUdf {
104158
105159 let res = variant_get (
106160 & variant_array,
107- GetOptions :: new_with_path ( VariantPath :: from ( variant_path) ) ,
108- ) ?
109- . as_any ( )
110- . downcast_ref :: < StructArray > ( )
111- . unwrap ( )
112- . clone ( ) ;
113-
114- ColumnarValue :: Scalar ( ScalarValue :: Struct ( Arc :: new ( res) ) )
161+ build_get_options ( VariantPath :: from ( variant_path) , & type_field) ,
162+ ) ?;
163+
164+ let scalar = ScalarValue :: try_from_array ( res. as_ref ( ) , 0 ) ?;
165+ ColumnarValue :: Scalar ( scalar)
115166 }
116167 ( ColumnarValue :: Array ( variant_array) , ColumnarValue :: Array ( variant_paths) ) => {
117168 if variant_array. len ( ) != variant_paths. len ( ) {
@@ -134,7 +185,7 @@ impl ScalarUDFImpl for VariantGetUdf {
134185
135186 let res = variant_get (
136187 & arr,
137- GetOptions :: new_with_path ( VariantPath :: from ( path. unwrap_or_default ( ) ) ) ,
188+ build_get_options ( VariantPath :: from ( path. unwrap_or_default ( ) ) , & type_field ) ,
138189 ) ?;
139190
140191 out. push ( res) ;
@@ -157,7 +208,7 @@ impl ScalarUDFImpl for VariantGetUdf {
157208 let path = path. unwrap_or_default ( ) ;
158209 let res = variant_get (
159210 & variant_array,
160- GetOptions :: new_with_path ( VariantPath :: from ( path) ) ,
211+ build_get_options ( VariantPath :: from ( path) , & type_field ) ,
161212 ) ?;
162213
163214 out. push ( res) ;
@@ -174,7 +225,7 @@ impl ScalarUDFImpl for VariantGetUdf {
174225
175226#[ cfg( test) ]
176227mod tests {
177- use arrow:: array:: { Array , BinaryViewArray } ;
228+ use arrow:: array:: { Array , BinaryViewArray , Int64Array } ;
178229 use arrow_schema:: { Field , Fields } ;
179230 use datafusion:: logical_expr:: { ReturnFieldArgs , ScalarFunctionArgs } ;
180231 use parquet_variant:: Variant ;
@@ -183,48 +234,96 @@ mod tests {
183234
184235 use super :: * ;
185236
186- #[ test]
187- fn test_get_variant_scalar ( ) {
188- let expected_json = serde_json:: json!( {
189- "name" : "norm" ,
190- "age" : 50 ,
191- "list" : [ false , true , ( ) ]
192- } ) ;
193-
194- let json_str = expected_json. to_string ( ) ;
237+ fn variant_scalar_from_json ( json : serde_json:: Value ) -> ScalarValue {
195238 let mut builder = VariantArrayBuilder :: new ( 1 ) ;
196- builder. append_json ( json_str. as_str ( ) ) . unwrap ( ) ;
239+ builder. append_json ( json. to_string ( ) . as_str ( ) ) . unwrap ( ) ;
240+ ScalarValue :: Struct ( Arc :: new ( builder. build ( ) . into ( ) ) )
241+ }
197242
198- let input = builder. build ( ) . into ( ) ;
243+ fn variant_array_from_json_rows ( json_rows : & [ serde_json:: Value ] ) -> ArrayRef {
244+ let mut builder = VariantArrayBuilder :: new ( json_rows. len ( ) ) ;
245+ for value in json_rows {
246+ builder. append_json ( value. to_string ( ) . as_str ( ) ) . unwrap ( ) ;
247+ }
248+ let variant_array: StructArray = builder. build ( ) . into ( ) ;
249+ Arc :: new ( variant_array) as ArrayRef
250+ }
199251
200- let variant_input = ScalarValue :: Struct ( Arc :: new ( input) ) ;
201- let path = "name" ;
252+ fn standard_arg_fields ( with_type_hint : bool ) -> Vec < FieldRef > {
253+ let mut fields = vec ! [
254+ Arc :: new(
255+ Field :: new( "input" , DataType :: Struct ( Fields :: empty( ) ) , true )
256+ . with_extension_type( VariantType ) ,
257+ ) ,
258+ Arc :: new( Field :: new( "path" , DataType :: Utf8 , true ) ) ,
259+ ] ;
260+ if with_type_hint {
261+ fields. push ( Arc :: new ( Field :: new ( "type" , DataType :: Utf8 , true ) ) ) ;
262+ }
263+ fields
264+ }
202265
203- let udf = VariantGetUdf :: default ( ) ;
266+ fn get_return_field (
267+ udf : & VariantGetUdf ,
268+ arg_fields : & [ FieldRef ] ,
269+ type_hint_value : Option < & ScalarValue > ,
270+ ) -> FieldRef {
271+ let scalar_arguments: Vec < Option < & ScalarValue > > = if let Some ( hint) = type_hint_value {
272+ vec ! [ None , None , Some ( hint) ]
273+ } else {
274+ vec ! [ ]
275+ } ;
204276
205- let arg_field = Arc :: new (
206- Field :: new ( "input" , DataType :: Struct ( Fields :: empty ( ) ) , true )
207- . with_extension_type ( VariantType ) ,
208- ) ;
209- let arg_field2 = Arc :: new ( Field :: new ( "path" , DataType :: Utf8 , true ) ) ;
277+ udf. return_field_from_args ( ReturnFieldArgs {
278+ arg_fields,
279+ scalar_arguments : & scalar_arguments,
280+ } )
281+ . unwrap ( )
282+ }
210283
211- let return_field = udf
212- . return_field_from_args ( ReturnFieldArgs {
213- arg_fields : & [ arg_field. clone ( ) , arg_field2. clone ( ) ] ,
214- scalar_arguments : & [ ] ,
215- } )
216- . unwrap ( ) ;
284+ fn build_scalar_function_args (
285+ variant_input : ColumnarValue ,
286+ path : & str ,
287+ arg_fields : Vec < FieldRef > ,
288+ return_field : FieldRef ,
289+ type_hint : Option < ScalarValue > ,
290+ ) -> ScalarFunctionArgs {
291+ let mut args = vec ! [
292+ variant_input,
293+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( path. to_string( ) ) ) ) ,
294+ ] ;
295+ if let Some ( hint) = type_hint {
296+ args. push ( ColumnarValue :: Scalar ( hint) ) ;
297+ }
217298
218- let args = ScalarFunctionArgs {
219- args : vec ! [
220- ColumnarValue :: Scalar ( variant_input) ,
221- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( path. to_string( ) ) ) ) ,
222- ] ,
299+ ScalarFunctionArgs {
300+ args,
223301 return_field,
224- arg_fields : vec ! [ arg_field ] ,
302+ arg_fields,
225303 number_rows : Default :: default ( ) ,
226304 config_options : Default :: default ( ) ,
227- } ;
305+ }
306+ }
307+
308+ #[ test]
309+ fn test_get_variant_scalar ( ) {
310+ let variant_input = variant_scalar_from_json ( serde_json:: json!( {
311+ "name" : "norm" ,
312+ "age" : 50 ,
313+ "list" : [ false , true , ( ) ]
314+ } ) ) ;
315+
316+ let udf = VariantGetUdf :: default ( ) ;
317+ let arg_fields = standard_arg_fields ( false ) ;
318+ let return_field = get_return_field ( & udf, & arg_fields, None ) ;
319+
320+ let args = build_scalar_function_args (
321+ ColumnarValue :: Scalar ( variant_input) ,
322+ "name" ,
323+ arg_fields,
324+ return_field,
325+ None ,
326+ ) ;
228327
229328 let result = udf. invoke_with_args ( args) . unwrap ( ) ;
230329
@@ -247,9 +346,81 @@ mod tests {
247346
248347 let metadata = metadata_arr. value ( 0 ) ;
249348 let value = value_arr. value ( 0 ) ;
250-
251349 let v = Variant :: try_new ( metadata, value) . unwrap ( ) ;
252350
253351 assert_eq ! ( v, Variant :: from( "norm" ) )
254352 }
353+
/// A literal `"Int64"` type hint makes the planned return field `Int64`.
#[test]
fn test_return_field_with_type_hint() {
    let hint = ScalarValue::Utf8(Some("Int64".to_string()));
    let fields = standard_arg_fields(true);

    let planned = get_return_field(&VariantGetUdf::default(), &fields, Some(&hint));

    assert_eq!(planned.data_type(), &DataType::Int64);
}
363+
/// A scalar variant input with an `"Int64"` hint yields an `Int64` scalar
/// holding the value found at the requested path.
#[test]
fn test_get_variant_scalar_with_type_hint() {
    let udf = VariantGetUdf::default();
    let hint = ScalarValue::Utf8(Some("Int64".to_string()));
    let fields = standard_arg_fields(true);
    let planned = get_return_field(&udf, &fields, Some(&hint));

    let input = variant_scalar_from_json(serde_json::json!({
        "name": "norm",
        "age": 50,
    }));

    let call = build_scalar_function_args(
        ColumnarValue::Scalar(input),
        "age",
        fields,
        planned,
        Some(hint),
    );

    match udf.invoke_with_args(call).unwrap() {
        ColumnarValue::Scalar(ScalarValue::Int64(Some(age))) => assert_eq!(age, 50),
        _ => panic!("expected ScalarValue Int64"),
    }
}
392+
/// An array variant input with an `"Int64"` hint yields an `Int64Array`
/// with one extracted value per input row.
#[test]
fn test_get_variant_array_with_type_hint() {
    let rows = [
        serde_json::json!({ "age": 50 }),
        serde_json::json!({ "age": 60 }),
    ];
    let input = variant_array_from_json_rows(&rows);

    let udf = VariantGetUdf::default();
    let hint = ScalarValue::Utf8(Some("Int64".to_string()));
    let fields = standard_arg_fields(true);
    let planned = get_return_field(&udf, &fields, Some(&hint));

    let call = build_scalar_function_args(
        ColumnarValue::Array(input),
        "age",
        fields,
        planned,
        Some(hint),
    );

    let output = match udf.invoke_with_args(call).unwrap() {
        ColumnarValue::Array(array) => array,
        _ => panic!("expected array output"),
    };

    let ages = output.as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(ages.len(), 2);
    for (row, expected) in [50, 60].into_iter().enumerate() {
        assert_eq!(ages.value(row), expected);
    }
}
255426}
0 commit comments