Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 25 additions & 36 deletions src/common/function/src/scalars/json/json_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,18 @@ trait JsonGetResultBuilder {
fn build(&mut self) -> ArrayRef;
}

fn result_builder(
len: usize,
with_type: Option<&DataType>,
) -> Result<Box<dyn JsonGetResultBuilder>> {
let builder = if let Some(t) = with_type {
match t {
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
Box::new(StringResultBuilder(StringViewBuilder::with_capacity(len)))
as Box<dyn JsonGetResultBuilder>
}
DataType::Int64 => Box::new(IntResultBuilder(Int64Builder::with_capacity(len))),
DataType::Float64 => Box::new(FloatResultBuilder(Float64Builder::with_capacity(len))),
DataType::Boolean => Box::new(BoolResultBuilder(BooleanBuilder::with_capacity(len))),
t => {
return exec_err!("json_get with unknown type {t}");
}
fn result_builder(len: usize, with_type: &DataType) -> Result<Box<dyn JsonGetResultBuilder>> {
let builder = match with_type {
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
Box::new(StringResultBuilder(StringViewBuilder::with_capacity(len)))
as Box<dyn JsonGetResultBuilder>
}
DataType::Int64 => Box::new(IntResultBuilder(Int64Builder::with_capacity(len))),
DataType::Float64 => Box::new(FloatResultBuilder(Float64Builder::with_capacity(len))),
DataType::Boolean => Box::new(BoolResultBuilder(BooleanBuilder::with_capacity(len))),
t => {
return exec_err!("json_get with unknown type {t}");
}
} else {
Box::new(StringResultBuilder(StringViewBuilder::with_capacity(len)))
};
Ok(builder)
}
Expand Down Expand Up @@ -339,7 +332,7 @@ fn jsonb_get(
Ok(())
}

fn json_struct_get(array: &ArrayRef, path: &str, with_type: Option<&DataType>) -> Result<ArrayRef> {
fn json_struct_get(array: &ArrayRef, path: &str, with_type: &DataType) -> Result<ArrayRef> {
let path = path.trim_start_matches("$");

// Fast path: if the JSON array fields can be directly indexed into by the `path`, simply get
Expand All @@ -356,20 +349,13 @@ fn json_struct_get(array: &ArrayRef, path: &str, with_type: Option<&DataType>) -
return exec_err!("unknown JSON array datatype: {}", current.data_type());
};
let Some(sub_json) = json.column_by_name(segment) else {
return Ok(new_null_array(
with_type.unwrap_or(&DataType::Utf8View),
array.len(),
));
return Ok(new_null_array(with_type, array.len()));
};
current = sub_json;
}

// Build the result array with optional value mapper.
fn build_with<F>(
input: &ArrayRef,
with_type: Option<&DataType>,
value_mapper: F,
) -> Result<ArrayRef>
fn build_with<F>(input: &ArrayRef, with_type: &DataType, value_mapper: F) -> Result<ArrayRef>
where
for<'a> F: Fn(&'a Value) -> Option<&'a Value>,
{
Expand Down Expand Up @@ -397,20 +383,18 @@ fn json_struct_get(array: &ArrayRef, path: &str, with_type: Option<&DataType>) -
}

if direct {
let casted = if let Some(with_type) = with_type
&& current.data_type() != with_type
{
let casted = if current.data_type() != with_type {
match (current.data_type(), with_type) {
(DataType::Binary, _) => {
// Fall back to the slow path if the found JSON sub-array is serialized to bytes
// (because of JSON type conflicting)
build_with(current, Some(with_type), |v| Some(v))?
build_with(current, with_type, |v| Some(v))?
}
(DataType::List(_) | DataType::Struct(_), with_type) if with_type.is_string() => {
// Special handle for wanted array is string (Arrow cast is not working here if
// the datatype is list or struct), because it could be used in displaying the
// result.
build_with(current, Some(with_type), |v| Some(v))?
build_with(current, with_type, |v| Some(v))?
}
(_, with_type) if with_type.is_string() => {
// Same special handle for wanted array is string as above, except for simply
Expand Down Expand Up @@ -510,17 +494,22 @@ impl Function for JsonGetWithType {
);
};

let with_type = args.args.get(2).map(|x| x.data_type());
let with_type = args
.args
.get(2)
.map(|x| x.data_type())
.unwrap_or(DataType::Utf8View);

let result = match arg0.data_type() {
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
let arg0 = compute::cast(&arg0, &DataType::BinaryView)?;
let jsons = arg0.as_binary_view();

let mut builder = result_builder(len, with_type.as_ref())?;
let mut builder = result_builder(len, &with_type)?;
jsonb_get(jsons, path, builder.as_mut())?;
builder.build()
}
DataType::Struct(_) => json_struct_get(&arg0, path, with_type.as_ref())?,
DataType::Struct(_) => json_struct_get(&arg0, path, &with_type)?,
_ => {
return exec_err!("JSON_GET not supported argument type {}", arg0.data_type());
}
Expand Down
178 changes: 91 additions & 87 deletions src/common/function/src/scalars/json/json_get_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,92 +40,111 @@ impl FunctionRewrite for JsonGetRewriter {
_schema: &DFSchema,
_config: &ConfigOptions,
) -> Result<Transformed<Expr>> {
let transform = match &expr {
Expr::Cast(cast) => rewrite_json_get_cast(cast),
Expr::ScalarFunction(scalar_func) => rewrite_arrow_cast_json_get(scalar_func),
_ => None,
};
Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
Ok(match expr {
Expr::Cast(cast) => inject_type_from_cast_expr(cast)?,
Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?,
expr => Transformed::no(expr),
})
}
}

fn is_json_get_function_call(scalar_func: &ScalarFunction) -> bool {
scalar_func.func.name().to_ascii_lowercase() == JsonGetWithType::NAME
&& scalar_func.args.len() == 2
}

fn rewrite_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
let scalar_func = extract_scalar_function(&cast.expr)?;
if is_json_get_function_call(scalar_func) {
let null_expr = Expr::Literal(ScalarValue::Null, None);
let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
expr: Box::new(null_expr),
data_type: cast.data_type.clone(),
});

let mut args = scalar_func.args.clone();
args.push(null_cast);
// Expr::Cast(
// Expr::ScalarFunction(
// json_get(column, path),
// <data_type>
// )
// )
// =>
// Expr::ScalarFunction(
// json_get(column, path, <data_type>)
// )
fn inject_type_from_cast_expr(cast: Cast) -> Result<Transformed<Expr>> {
let Cast { expr, data_type } = cast;

let mut json_get = match *expr {
Expr::ScalarFunction(f)
if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
{
f
}
expr => {
return Ok(Transformed::no(Expr::Cast(Cast {
expr: Box::new(expr),
data_type,
})));
}
};

Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
func: scalar_func.func.clone(),
args,
})))
} else {
None
}
let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?;
json_get.args.push(with_type);
Ok(Transformed::yes(Expr::ScalarFunction(json_get)))
}

// Handle Arrow cast function: cast(json_get(a, 'path'), 'Int64')
fn rewrite_arrow_cast_json_get(scalar_func: &ScalarFunction) -> Option<Transformed<Expr>> {
// Expr::ScalarFunction(
// arrow_cast(
// Expr::ScalarFunction(
// json_get(column, path),
// ),
// <data_type>
// )
// )
// =>
// Expr::ScalarFunction(
// json_get(column, path, <data_type>)
// )
fn inject_type_from_cast_func(cast: ScalarFunction) -> Result<Transformed<Expr>> {
let ScalarFunction { func, args } = cast;

// Check if this is an Arrow cast function
// The function name might be "arrow_cast" or similar
let func_name = scalar_func.func.name().to_ascii_lowercase();
let func_name = func.name().to_ascii_lowercase();
if !func_name.contains("arrow_cast") {
return None;
let original = Expr::ScalarFunction(ScalarFunction { func, args });
return Ok(Transformed::no(original));
}

// Arrow cast function should have exactly 2 arguments:
// 1. The expression to cast (could be json_get)
// 2. The target type as a string literal
if scalar_func.args.len() != 2 {
return None;
if args.len() != 2 {
let original = Expr::ScalarFunction(ScalarFunction { func, args });
return Ok(Transformed::no(original));
}

// Extract the inner json_get function
let json_get_func = extract_scalar_function(&scalar_func.args[0])?;

// Check if it's a json_get function
if is_json_get_function_call(json_get_func) {
// Get the target type from the second argument
let target_type = extract_string_literal(&scalar_func.args[1])?;
let data_type = parse_data_type_from_string(&target_type)?;

// Create the null expression with the same type
let null_expr = Expr::Literal(ScalarValue::Null, None);
let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
expr: Box::new(null_expr),
data_type,
let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!());

let Some(with_type) = arg1
.as_literal()
.and_then(|x| x.try_as_str())
.flatten()
.and_then(parse_data_type_from_string)
else {
let original = Expr::ScalarFunction(ScalarFunction {
func,
args: vec![arg0, arg1],
});
return Ok(Transformed::no(original));
};

let mut json_get = match arg0 {
Expr::ScalarFunction(f)
if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
{
f
}
arg0 => {
let original = Expr::ScalarFunction(ScalarFunction {
func,
args: vec![arg0, arg1],
});
return Ok(Transformed::no(original));
}
};

// Create the new json_get_with_type function with the null parameter
let mut args = json_get_func.args.clone();
args.push(null_cast);

Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
func: json_get_func.func.clone(),
args,
})))
} else {
None
}
}
let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?;
json_get.args.push(with_type);

// Extract string literal from an expression
fn extract_string_literal(expr: &Expr) -> Option<String> {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s.clone()),
_ => None,
}
let rewritten = Expr::ScalarFunction(json_get);
Ok(Transformed::yes(rewritten))
}

// Parse a data type from a string representation
Expand All @@ -149,13 +168,6 @@ fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
}
}

fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
match expr {
Expr::ScalarFunction(func) => Some(func),
_ => None,
}
}

#[cfg(test)]
mod tests {
use arrow_schema::DataType;
Expand Down Expand Up @@ -221,12 +233,8 @@ mod tests {

// Third argument should be a null cast to Int8
match &func.args[2] {
Expr::Cast(Cast { expr, data_type }) => {
assert_eq!(*data_type, DataType::Int8);
match expr.as_ref() {
Expr::Literal(ScalarValue::Null, _) => {}
_ => panic!("Third argument should be a null cast"),
}
Expr::Literal(value, _) => {
assert_eq!(value.data_type(), DataType::Int8);
}
_ => panic!("Third argument should be a cast expression"),
}
Expand Down Expand Up @@ -314,12 +322,8 @@ mod tests {

// Third argument should be a null cast to Int64
match &func.args[2] {
Expr::Cast(Cast { expr, data_type }) => {
assert_eq!(*data_type, DataType::Int64);
match expr.as_ref() {
Expr::Literal(ScalarValue::Null, _) => {}
_ => panic!("Third argument should be a null cast"),
}
Expr::Literal(value, _) => {
assert_eq!(value.data_type(), DataType::Int64);
}
_ => panic!("Third argument should be a cast expression"),
}
Expand Down
Loading
Loading