diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 3111fbd540777..4b25272a6e987 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -425,6 +425,11 @@ "Collations can only be applied to string types, but the JSON data type is <jsonType>." ] }, + "INVALID_LAMBDA_EXPRESSION": { + "message": [ + "Invalid SQL lambda expression: '<expression>'. Expected format: 'param -> expression'." + ] + }, "INVALID_MULTIPLE_ARGUMENT_CONDITIONS": { "message": [ "[{arg_names}] cannot be <arg_cond>." diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py index c7acf504098ac..de1c9c4b7b65a 100644 --- a/python/pyspark/sql/classic/column.py +++ b/python/pyspark/sql/classic/column.py @@ -610,7 +610,28 @@ def over(self, window: "WindowSpec") -> ParentColumn: jc = self._jc.over(window._jspec) return Column(jc) - def transform(self, f: Callable[[ParentColumn], ParentColumn]) -> ParentColumn: + def transform(self, f: Union[Callable[[ParentColumn], ParentColumn], str]) -> ParentColumn: + if isinstance(f, str): + from py4j.java_gateway import JVMView + + arrow_idx = f.find("->") + if arrow_idx == -1: + raise PySparkValueError( + errorClass="INVALID_LAMBDA_EXPRESSION", + messageParameters={"expression": f}, + ) + + param = f[:arrow_idx].strip() + if not param.isidentifier(): + raise PySparkValueError( + errorClass="INVALID_LAMBDA_EXPRESSION", + messageParameters={"expression": f}, + ) + + sc = get_active_spark_context() + jvm = cast(JVMView, sc._jvm) + jresult = jvm.PythonSQLUtils.applyLambda(self._jc, f) + return Column(jresult) return f(self) def outer(self) -> ParentColumn: diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index a38ed64eb700d..04d9b2e0ed917 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -1512,24 +1512,24 @@ def over(self, window: "WindowSpec") -> "Column": ... 
@dispatch_col_method - def transform(self, f: Callable[["Column"], "Column"]) -> "Column": + def transform(self, f: Union[Callable[["Column"], "Column"], str]) -> "Column": """ - Applies a transformation function to this column. + Applies a transformation to this column. - This method allows you to apply a function that takes a Column and returns a Column, - enabling method chaining and functional transformations. + Accepts either a Python callable or a SQL lambda expression string. .. versionadded:: 4.1.0 Parameters ---------- - f : callable - A function that takes a :class:`Column` and returns a :class:`Column`. + f : callable or str + A function that takes a :class:`Column` and returns a :class:`Column`, + or a SQL lambda expression string in the format ``'param -> expression'``. Returns ------- :class:`Column` - The result of applying the function to this column. + The result of applying the transformation to this column. Examples -------- @@ -1545,7 +1545,7 @@ def transform(self, f: Callable[["Column"], "Column"]) -> "Column": | WORLD| +------+ - Example 2: Use lambda functions + Example 2: Use Python lambda functions >>> df = spark.createDataFrame([(10,), (20,), (30,)], ["value"]) >>> df.select( @@ -1560,6 +1560,29 @@ def transform(self, f: Callable[["Column"], "Column"]) -> "Column": | 40| | 60| +------+ + + Example 3: Use SQL lambda expression + + >>> df = spark.createDataFrame([(1,), (2,), (3,)], ["value"]) + >>> df.select(df.value.transform('x -> x * 2').alias("result")).show() + +------+ + |result| + +------+ + | 2| + | 4| + | 6| + +------+ + + Example 4: SQL lambda with function calls + + >>> df = spark.createDataFrame([("hello",), ("world",)], ["text"]) + >>> df.select(df.text.transform('x -> upper(x)').alias("result")).show() + +------+ + |result| + +------+ + | HELLO| + | WORLD| + +------+ """ ... 
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 2fca83748f1ec..3a0ad84266efd 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -44,6 +44,7 @@ LiteralExpression, CaseWhen, SortOrder, + SQLExpression, SubqueryExpression, CastExpression, WindowExpression, @@ -465,7 +466,30 @@ def over(self, window: "WindowSpec") -> ParentColumn: # type: ignore[override] return Column(WindowExpression(windowFunction=self._expr, windowSpec=window)) - def transform(self, f: Callable[[ParentColumn], ParentColumn]) -> ParentColumn: + def transform(self, f: Union[Callable[[ParentColumn], ParentColumn], str]) -> ParentColumn: + if isinstance(f, str): + arrow_idx = f.find("->") + if arrow_idx == -1: + raise PySparkValueError( + errorClass="INVALID_LAMBDA_EXPRESSION", + messageParameters={"expression": f}, + ) + + param = f[:arrow_idx].strip() + if not param.isidentifier(): + raise PySparkValueError( + errorClass="INVALID_LAMBDA_EXPRESSION", + messageParameters={"expression": f}, + ) + + # Build: transform(array(col), lambda)[0] + # The server-side parser handles the lambda expression natively. 
+ lambda_expr = SQLExpression(f) + array_col = Column(UnresolvedFunction("array", [self._expr])) + transform_col = Column( + UnresolvedFunction("transform", [array_col._expr, lambda_expr]) + ) + return transform_col[0] return f(self) def outer(self) -> ParentColumn: diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 1983b291a5a36..9361b3102e90a 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -502,6 +502,69 @@ def test_transform(self): self.assertEqual(result[1][0], 40) self.assertEqual(result[2][0], 60) + def test_transform_sql_lambda_arithmetic(self): + """Test transform() with SQL lambda expression for arithmetic.""" + df = self.spark.createDataFrame([(1,), (2,), (3,)], ["value"]) + result = df.select(df.value.transform("x -> x * 2").alias("result")).collect() + self.assertEqual(result[0][0], 2) + self.assertEqual(result[1][0], 4) + self.assertEqual(result[2][0], 6) + + def test_transform_sql_lambda_function_call(self): + """Test transform() with SQL lambda using function calls.""" + df = self.spark.createDataFrame([("hello",), ("world",)], ["text"]) + result = df.select(df.text.transform("x -> upper(x)").alias("result")).collect() + self.assertEqual(result[0][0], "HELLO") + self.assertEqual(result[1][0], "WORLD") + + def test_transform_sql_lambda_conditional(self): + """Test transform() with SQL lambda using CASE WHEN.""" + df = self.spark.createDataFrame([(-1,), (0,), (1,)], ["value"]) + result = df.select( + df.value.transform("x -> CASE WHEN x > 0 THEN x ELSE 0 END").alias("result") + ).collect() + self.assertEqual(result[0][0], 0) + self.assertEqual(result[1][0], 0) + self.assertEqual(result[2][0], 1) + + def test_transform_sql_lambda_with_nulls(self): + """Test transform() with SQL lambda handles nulls.""" + df = self.spark.createDataFrame([(1,), (None,), (3,)], ["value"]) + result = df.select(df.value.transform("x -> x * 2").alias("result")).collect() + 
self.assertEqual(result[0][0], 2) + self.assertIsNone(result[1][0]) + self.assertEqual(result[2][0], 6) + + def test_transform_sql_lambda_chaining(self): + """Test chaining SQL lambda transform calls.""" + df = self.spark.createDataFrame([(5,)], ["value"]) + result = df.select( + df.value.transform("x -> x * 2").transform("y -> y + 1").alias("result") + ).collect() + self.assertEqual(result[0][0], 11) + + def test_transform_sql_lambda_mixed_chaining(self): + """Test chaining SQL lambda with Python callable.""" + df = self.spark.createDataFrame([(5,)], ["value"]) + result = df.select( + df.value.transform("x -> x * 2").transform(lambda c: c + 1).alias("result") + ).collect() + self.assertEqual(result[0][0], 11) + + def test_transform_sql_lambda_invalid_no_arrow(self): + """Test transform() raises error for string without ->.""" + df = self.spark.createDataFrame([(1,)], ["value"]) + with self.assertRaises(PySparkValueError) as ctx: + df.select(df.value.transform("x * 2")).collect() + self.assertIn("INVALID_LAMBDA_EXPRESSION", ctx.exception.getErrorClass()) + + def test_transform_sql_lambda_invalid_param(self): + """Test transform() raises error for invalid parameter name.""" + df = self.spark.createDataFrame([(1,)], ["value"]) + with self.assertRaises(PySparkValueError) as ctx: + df.select(df.value.transform("123 -> x * 2")).collect() + self.assertIn("INVALID_LAMBDA_EXPRESSION", ctx.exception.getErrorClass()) + class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase): pass diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 5607c98bf29e5..52ab635ba2297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ExpressionUtils import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -181,6 +182,25 @@ private[sql] object PythonSQLUtils extends Logging { Column(internal.LambdaFunction(function.node, arguments)) } + def applyLambda(col: Column, lambdaExpr: String): Column = { + val parsed = CatalystSqlParser.parseExpression(lambdaExpr) + parsed match { + case LambdaFunction(function, arguments, _) if arguments.size == 1 => + val colExpr = expression(col) + val param = arguments.head.asInstanceOf[UnresolvedNamedLambdaVariable] + val replaced = function.transform { + case v: UnresolvedNamedLambdaVariable if v.nameParts == param.nameParts => colExpr + } + ExpressionUtils.column(replaced) + case _: LambdaFunction => + throw new IllegalArgumentException( + s"Expected a single-parameter lambda expression, but got: $lambdaExpr") + case _ => + throw new IllegalArgumentException( + s"Expected a lambda expression (param -> expr), but got: $lambdaExpr") + } + } + def namedArgumentExpression(name: String, e: Column): Column = Column(NamedArgumentExpression(name, expression(e)))