Skip to content

Commit 29c4c60

Browse files
authored
Update variant_get to allow optional type (#17)
1 parent 41237c0 commit 29c4c60

File tree

1 file changed

+222
-51
lines changed

1 file changed

+222
-51
lines changed

src/variant_get.rs

Lines changed: 222 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use arrow::{
44
array::{Array, ArrayRef, StructArray},
55
compute::concat,
66
};
7-
use arrow_schema::{DataType, Field, Fields};
7+
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields};
88
use datafusion::{
9-
common::{exec_datafusion_err, exec_err},
9+
common::{arrow_datafusion_err, exec_datafusion_err, exec_err},
1010
error::{DataFusionError, Result},
1111
logical_expr::{
1212
ColumnarValue, ReturnFieldArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -20,6 +20,44 @@ use crate::shared::{
2020
try_field_as_variant_array, try_parse_string_columnar, try_parse_string_scalar,
2121
};
2222

23+
/// Turns a scalar type-hint argument into the output field for `variant_get`.
///
/// `scalar` must be a non-null `Utf8`, `Utf8View`, or `LargeUtf8` value whose
/// text parses as an Arrow `DataType` (via `str::parse::<DataType>()`).
///
/// # Errors
///
/// Returns an execution error when `scalar` is any other variant (including a
/// null string), and an error when the text cannot be parsed as a `DataType`.
///
/// # Returns
///
/// A `FieldRef` named `field_name` carrying the parsed data type; the field is
/// constructed with `true` as its third `Field::new` argument (nullable).
fn type_hint_from_scalar(field_name: &str, scalar: &ScalarValue) -> Result<FieldRef> {
// Accept only the three string scalar variants holding `Some(..)`; everything
// else is rejected with a message naming the offending data type.
24+
let type_name = match scalar {
25+
ScalarValue::Utf8(Some(value))
26+
| ScalarValue::Utf8View(Some(value))
27+
| ScalarValue::LargeUtf8(Some(value)) => value.as_str(),
28+
other => {
29+
return exec_err!(
30+
"type hint must be a non-null UTF8 literal, got {}",
31+
other.data_type()
32+
);
33+
}
34+
};
35+
// Parse failures are re-reported as execution errors carrying only the parse
// message; any other Arrow error is wrapped through `arrow_datafusion_err!`.
36+
let casted_type = match type_name.parse::<DataType>() {
37+
Ok(data_type) => Ok(data_type),
38+
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
39+
Err(e) => Err(arrow_datafusion_err!(e)),
40+
}?;
41+
42+
Ok(Arc::new(Field::new(field_name, casted_type, true)))
43+
}
44+
45+
/// Resolves the type-hint argument as supplied at invocation time.
///
/// A `ColumnarValue::Scalar` is forwarded to `type_hint_from_scalar`; a
/// `ColumnarValue::Array` is rejected with an execution error, so the hint
/// cannot vary per row.
fn type_hint_from_value(field_name: &str, arg: &ColumnarValue) -> Result<FieldRef> {
46+
match arg {
47+
ColumnarValue::Scalar(value) => type_hint_from_scalar(field_name, value),
48+
ColumnarValue::Array(_) => {
49+
exec_err!("type hint argument must be a scalar UTF8 literal")
50+
}
51+
}
52+
}
53+
54+
/// Builds the `GetOptions` passed to `variant_get` for a given path.
///
/// When `as_type` is `Some`, a clone of the field is attached through
/// `with_as_type`; when it is `None`, the options are created from the path
/// alone.
fn build_get_options<'a>(path: VariantPath<'a>, as_type: &Option<FieldRef>) -> GetOptions<'a> {
55+
match as_type {
56+
Some(field) => GetOptions::new_with_path(path).with_as_type(Some(field.clone())),
57+
None => GetOptions::new_with_path(path),
58+
}
59+
}
60+
2361
#[derive(Debug, Hash, PartialEq, Eq)]
2462
pub struct VariantGetUdf {
2563
signature: Signature,
@@ -28,7 +66,10 @@ pub struct VariantGetUdf {
2866
impl Default for VariantGetUdf {
2967
fn default() -> Self {
3068
Self {
31-
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
69+
signature: Signature::new(
70+
TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]),
71+
Volatility::Immutable,
72+
),
3273
}
3374
}
3475
}
@@ -52,7 +93,14 @@ impl ScalarUDFImpl for VariantGetUdf {
5293
))
5394
}
5495

55-
fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Arc<Field>> {
96+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<Arc<Field>> {
97+
if let Some(maybe_scalar) = args.scalar_arguments.get(2) {
98+
let scalar = maybe_scalar.ok_or_else(|| {
99+
exec_datafusion_err!("type hint argument to variant_get must be a literal")
100+
})?;
101+
return type_hint_from_scalar(self.name(), scalar);
102+
}
103+
56104
let data_type = DataType::Struct(Fields::from(vec![
57105
Field::new("metadata", DataType::BinaryView, false),
58106
Field::new("value", DataType::BinaryView, true),
@@ -67,8 +115,10 @@ impl ScalarUDFImpl for VariantGetUdf {
67115
&self,
68116
args: datafusion::logical_expr::ScalarFunctionArgs,
69117
) -> Result<ColumnarValue> {
70-
let [variant_arg, variant_path] = args.args.as_slice() else {
71-
return exec_err!("expected 2 arguments");
118+
let (variant_arg, variant_path, type_arg) = match args.args.as_slice() {
119+
[variant_arg, variant_path] => (variant_arg, variant_path, None),
120+
[variant_arg, variant_path, type_arg] => (variant_arg, variant_path, Some(type_arg)),
121+
_ => return exec_err!("expected 2 or 3 arguments"),
72122
};
73123

74124
let variant_field = args
@@ -78,6 +128,10 @@ impl ScalarUDFImpl for VariantGetUdf {
78128

79129
try_field_as_variant_array(variant_field.as_ref())?;
80130

131+
let type_field = type_arg
132+
.map(|arg| type_hint_from_value(self.name(), arg))
133+
.transpose()?;
134+
81135
let out = match (variant_arg, variant_path) {
82136
(ColumnarValue::Array(variant_array), ColumnarValue::Scalar(variant_path)) => {
83137
let variant_path = try_parse_string_scalar(variant_path)?
@@ -86,7 +140,7 @@ impl ScalarUDFImpl for VariantGetUdf {
86140

87141
let res = variant_get(
88142
variant_array,
89-
GetOptions::new_with_path(VariantPath::from(variant_path)),
143+
build_get_options(VariantPath::from(variant_path), &type_field),
90144
)?;
91145

92146
ColumnarValue::Array(res)
@@ -104,14 +158,11 @@ impl ScalarUDFImpl for VariantGetUdf {
104158

105159
let res = variant_get(
106160
&variant_array,
107-
GetOptions::new_with_path(VariantPath::from(variant_path)),
108-
)?
109-
.as_any()
110-
.downcast_ref::<StructArray>()
111-
.unwrap()
112-
.clone();
113-
114-
ColumnarValue::Scalar(ScalarValue::Struct(Arc::new(res)))
161+
build_get_options(VariantPath::from(variant_path), &type_field),
162+
)?;
163+
164+
let scalar = ScalarValue::try_from_array(res.as_ref(), 0)?;
165+
ColumnarValue::Scalar(scalar)
115166
}
116167
(ColumnarValue::Array(variant_array), ColumnarValue::Array(variant_paths)) => {
117168
if variant_array.len() != variant_paths.len() {
@@ -134,7 +185,7 @@ impl ScalarUDFImpl for VariantGetUdf {
134185

135186
let res = variant_get(
136187
&arr,
137-
GetOptions::new_with_path(VariantPath::from(path.unwrap_or_default())),
188+
build_get_options(VariantPath::from(path.unwrap_or_default()), &type_field),
138189
)?;
139190

140191
out.push(res);
@@ -157,7 +208,7 @@ impl ScalarUDFImpl for VariantGetUdf {
157208
let path = path.unwrap_or_default();
158209
let res = variant_get(
159210
&variant_array,
160-
GetOptions::new_with_path(VariantPath::from(path)),
211+
build_get_options(VariantPath::from(path), &type_field),
161212
)?;
162213

163214
out.push(res);
@@ -174,7 +225,7 @@ impl ScalarUDFImpl for VariantGetUdf {
174225

175226
#[cfg(test)]
176227
mod tests {
177-
use arrow::array::{Array, BinaryViewArray};
228+
use arrow::array::{Array, BinaryViewArray, Int64Array};
178229
use arrow_schema::{Field, Fields};
179230
use datafusion::logical_expr::{ReturnFieldArgs, ScalarFunctionArgs};
180231
use parquet_variant::Variant;
@@ -183,48 +234,96 @@ mod tests {
183234

184235
use super::*;
185236

186-
#[test]
187-
fn test_get_variant_scalar() {
188-
let expected_json = serde_json::json!({
189-
"name": "norm",
190-
"age": 50,
191-
"list": [false, true, ()]
192-
});
193-
194-
let json_str = expected_json.to_string();
237+
fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue {
195238
let mut builder = VariantArrayBuilder::new(1);
196-
builder.append_json(json_str.as_str()).unwrap();
239+
builder.append_json(json.to_string().as_str()).unwrap();
240+
ScalarValue::Struct(Arc::new(builder.build().into()))
241+
}
197242

198-
let input = builder.build().into();
243+
/// Test helper: builds a variant array with one row per JSON value.
///
/// Each value is serialized with `to_string()` and appended through
/// `append_json`; the `unwrap()` panics if an append fails. The built array is
/// converted into a `StructArray` and returned as an `ArrayRef`.
fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef {
// NOTE(review): the argument to `VariantArrayBuilder::new` is presumably a
// row-capacity hint — confirm against the builder's documentation.
244+
let mut builder = VariantArrayBuilder::new(json_rows.len());
245+
for value in json_rows {
246+
builder.append_json(value.to_string().as_str()).unwrap();
247+
}
248+
let variant_array: StructArray = builder.build().into();
249+
Arc::new(variant_array) as ArrayRef
250+
}
199251

200-
let variant_input = ScalarValue::Struct(Arc::new(input));
201-
let path = "name";
252+
/// Test helper: returns the argument fields used when calling the UDF.
///
/// The list always contains an `"input"` field (an empty struct type tagged
/// with the `VariantType` extension) and a `"path"` field of type `Utf8`.
/// When `with_type_hint` is true, a third `"type"` field of type `Utf8` is
/// appended for the optional type-hint argument.
fn standard_arg_fields(with_type_hint: bool) -> Vec<FieldRef> {
253+
let mut fields = vec![
254+
Arc::new(
255+
Field::new("input", DataType::Struct(Fields::empty()), true)
256+
.with_extension_type(VariantType),
257+
),
258+
Arc::new(Field::new("path", DataType::Utf8, true)),
259+
];
260+
if with_type_hint {
261+
fields.push(Arc::new(Field::new("type", DataType::Utf8, true)));
262+
}
263+
fields
264+
}
202265

203-
let udf = VariantGetUdf::default();
266+
fn get_return_field(
267+
udf: &VariantGetUdf,
268+
arg_fields: &[FieldRef],
269+
type_hint_value: Option<&ScalarValue>,
270+
) -> FieldRef {
271+
let scalar_arguments: Vec<Option<&ScalarValue>> = if let Some(hint) = type_hint_value {
272+
vec![None, None, Some(hint)]
273+
} else {
274+
vec![]
275+
};
204276

205-
let arg_field = Arc::new(
206-
Field::new("input", DataType::Struct(Fields::empty()), true)
207-
.with_extension_type(VariantType),
208-
);
209-
let arg_field2 = Arc::new(Field::new("path", DataType::Utf8, true));
277+
udf.return_field_from_args(ReturnFieldArgs {
278+
arg_fields,
279+
scalar_arguments: &scalar_arguments,
280+
})
281+
.unwrap()
282+
}
210283

211-
let return_field = udf
212-
.return_field_from_args(ReturnFieldArgs {
213-
arg_fields: &[arg_field.clone(), arg_field2.clone()],
214-
scalar_arguments: &[],
215-
})
216-
.unwrap();
284+
fn build_scalar_function_args(
285+
variant_input: ColumnarValue,
286+
path: &str,
287+
arg_fields: Vec<FieldRef>,
288+
return_field: FieldRef,
289+
type_hint: Option<ScalarValue>,
290+
) -> ScalarFunctionArgs {
291+
let mut args = vec![
292+
variant_input,
293+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(path.to_string()))),
294+
];
295+
if let Some(hint) = type_hint {
296+
args.push(ColumnarValue::Scalar(hint));
297+
}
217298

218-
let args = ScalarFunctionArgs {
219-
args: vec![
220-
ColumnarValue::Scalar(variant_input),
221-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(path.to_string()))),
222-
],
299+
ScalarFunctionArgs {
300+
args,
223301
return_field,
224-
arg_fields: vec![arg_field],
302+
arg_fields,
225303
number_rows: Default::default(),
226304
config_options: Default::default(),
227-
};
305+
}
306+
}
307+
308+
#[test]
309+
fn test_get_variant_scalar() {
310+
let variant_input = variant_scalar_from_json(serde_json::json!({
311+
"name": "norm",
312+
"age": 50,
313+
"list": [false, true, ()]
314+
}));
315+
316+
let udf = VariantGetUdf::default();
317+
let arg_fields = standard_arg_fields(false);
318+
let return_field = get_return_field(&udf, &arg_fields, None);
319+
320+
let args = build_scalar_function_args(
321+
ColumnarValue::Scalar(variant_input),
322+
"name",
323+
arg_fields,
324+
return_field,
325+
None,
326+
);
228327

229328
let result = udf.invoke_with_args(args).unwrap();
230329

@@ -247,9 +346,81 @@ mod tests {
247346

248347
let metadata = metadata_arr.value(0);
249348
let value = value_arr.value(0);
250-
251349
let v = Variant::try_new(metadata, value).unwrap();
252350

253351
assert_eq!(v, Variant::from("norm"))
254352
}
353+
354+
// Checks that supplying the type hint "Int64" makes the UDF report `Int64`
// as the data type of its return field.
#[test]
355+
fn test_return_field_with_type_hint() {
356+
let udf = VariantGetUdf::default();
357+
let arg_fields = standard_arg_fields(true);
358+
let type_hint = ScalarValue::Utf8(Some("Int64".to_string()));
359+
let return_field = get_return_field(&udf, &arg_fields, Some(&type_hint));
360+
361+
assert_eq!(return_field.data_type(), &DataType::Int64);
362+
}
363+
364+
// Checks the scalar-input path with a type hint: extracting "age" from a
// single variant value with the hint "Int64" must yield
// `ScalarValue::Int64(Some(50))`.
#[test]
365+
fn test_get_variant_scalar_with_type_hint() {
366+
let variant_input = variant_scalar_from_json(serde_json::json!({
367+
"name": "norm",
368+
"age": 50,
369+
}));
370+
371+
let udf = VariantGetUdf::default();
372+
let arg_fields = standard_arg_fields(true);
373+
let type_hint = ScalarValue::Utf8(Some("Int64".to_string()));
374+
let return_field = get_return_field(&udf, &arg_fields, Some(&type_hint));
375+
376+
let args = build_scalar_function_args(
377+
ColumnarValue::Scalar(variant_input),
378+
"age",
379+
arg_fields,
380+
return_field,
381+
Some(type_hint),
382+
);
383+
384+
let result = udf.invoke_with_args(args).unwrap();
385+
// Any result other than a non-null Int64 scalar fails the test here.
386+
let ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) = result else {
387+
panic!("expected ScalarValue Int64");
388+
};
389+
390+
assert_eq!(value, 50);
391+
}
392+
393+
// Checks the array-input path with a type hint: extracting "age" from two
// variant rows with the hint "Int64" must yield an `Int64Array` holding
// 50 and 60, in row order.
#[test]
394+
fn test_get_variant_array_with_type_hint() {
395+
let json_rows = vec![
396+
serde_json::json!({ "age": 50 }),
397+
serde_json::json!({ "age": 60 }),
398+
];
399+
400+
let variant_array = variant_array_from_json_rows(&json_rows);
401+
402+
let udf = VariantGetUdf::default();
403+
let arg_fields = standard_arg_fields(true);
404+
let type_hint = ScalarValue::Utf8(Some("Int64".to_string()));
405+
let return_field = get_return_field(&udf, &arg_fields, Some(&type_hint));
406+
407+
let args = build_scalar_function_args(
408+
ColumnarValue::Array(variant_array),
409+
"age",
410+
arg_fields,
411+
return_field,
412+
Some(type_hint),
413+
);
414+
415+
let result = udf.invoke_with_args(args).unwrap();
416+
417+
let ColumnarValue::Array(array) = result else {
418+
panic!("expected array output");
419+
};
420+
// The downcast's `unwrap()` fails the test if the output is not an Int64Array.
421+
let values = array.as_any().downcast_ref::<Int64Array>().unwrap();
422+
assert_eq!(values.len(), 2);
423+
assert_eq!(values.value(0), 50);
424+
assert_eq!(values.value(1), 60);
425+
}
255426
}

0 commit comments

Comments
 (0)