diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index 278058974a62..b95803a52b8a 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -26,8 +26,8 @@ use common_telemetry::tracing; use datafusion::common::{DFSchema, plan_err}; use datafusion::execution::context::SessionState; use datafusion::sql::planner::PlannerContext; -use datafusion_common::ToDFSchema; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::{ Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType, @@ -451,6 +451,19 @@ impl DfLogicalPlanner { casted_placeholders.insert(ph.id.clone()); } + // Handle arrow_cast(Placeholder, 'type_string') generated by SQL rewriter + if let DfExpr::ScalarFunction(scalar_func) = e + && scalar_func.name() == "arrow_cast" + && scalar_func.args.len() == 2 + && let DfExpr::Placeholder(ph) = &scalar_func.args[0] + && let DfExpr::Literal(ScalarValue::Utf8(Some(type_str)), _) = + &scalar_func.args[1] + && let Ok(data_type) = type_str.parse::() + { + placeholder_types.insert(ph.id.clone(), Some(data_type)); + casted_placeholders.insert(ph.id.clone()); + } + // Handle bare (non-casted) placeholders if let DfExpr::Placeholder(ph) = e && !casted_placeholders.contains(&ph.id) @@ -869,4 +882,25 @@ mod tests { assert_eq!(types.get("$3"), Some(&Some(DataType::Int32))); assert_eq!(types.get("$4"), Some(&Some(DataType::Utf8))); } + + #[tokio::test] + async fn test_get_inferred_parameter_types_arrow_cast() { + let plan = parse_sql_to_plan("SELECT $1::INT64, $2::FLOAT64, $3::INT16, $4::INT32, $5::UINT8, $6::UINT16, $7::UINT32").await; + let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap(); + + assert_eq!(types.get("$1"), Some(&Some(DataType::Int64))); + assert_eq!(types.get("$2"), Some(&Some(DataType::Float64))); + assert_eq!(types.get("$3"), Some(&Some(DataType::Int16))); + assert_eq!(types.get("$4"), Some(&Some(DataType::Int32))); + assert_eq!(types.get("$5"), Some(&Some(DataType::UInt8))); + assert_eq!(types.get("$6"), Some(&Some(DataType::UInt16))); + assert_eq!(types.get("$7"), Some(&Some(DataType::UInt32))); + + let plan = parse_sql_to_plan("SELECT $1::INT8, $2::FLOAT8, $3::INT2, $4::INT8").await; + let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap(); + + assert_eq!(types.get("$1"), Some(&Some(DataType::Int64))); + assert_eq!(types.get("$2"), Some(&Some(DataType::Float64))); + assert_eq!(types.get("$3"), Some(&Some(DataType::Int16))); + } } diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 887300a5733d..9fc5dccf5f31 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -398,8 +398,8 @@ pub(super) fn parameters_to_scalar_values( return Err(invalid_parameter_error( "unknown_parameter_type", Some(format!( - "Cannot get parameter type information for parameter {}", - idx + "Cannot get type for parameter {}, try to provide a type using ${}::", + idx, idx )), )); };