5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -425,6 +425,11 @@
"Collations can only be applied to string types, but the JSON data type is <jsonType>."
]
},
"INVALID_LAMBDA_EXPRESSION": {
"message": [
"Invalid SQL lambda expression: '<expression>'. Expected format: 'param -> expression'."
]
},
"INVALID_MULTIPLE_ARGUMENT_CONDITIONS": {
"message": [
"[{arg_names}] cannot be <condition>."
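For context, a minimal sketch of how an entry like this is consumed by PySpark's error framework (the expression string here is illustrative):

from pyspark.errors import PySparkValueError

raise PySparkValueError(
    errorClass="INVALID_LAMBDA_EXPRESSION",
    messageParameters={"expression": "x * 2"},
)
# Renders as:
# [INVALID_LAMBDA_EXPRESSION] Invalid SQL lambda expression: 'x * 2'.
# Expected format: 'param -> expression'.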
23 changes: 22 additions & 1 deletion python/pyspark/sql/classic/column.py
@@ -610,7 +610,28 @@ def over(self, window: "WindowSpec") -> ParentColumn:
jc = self._jc.over(window._jspec)
return Column(jc)

def transform(self, f: Callable[[ParentColumn], ParentColumn]) -> ParentColumn:
def transform(self, f: Union[Callable[[ParentColumn], ParentColumn], str]) -> ParentColumn:
if isinstance(f, str):
from py4j.java_gateway import JVMView

arrow_idx = f.find("->")
if arrow_idx == -1:
raise PySparkValueError(
errorClass="INVALID_LAMBDA_EXPRESSION",
messageParameters={"expression": f},
)

param = f[:arrow_idx].strip()
if not param.isidentifier():
raise PySparkValueError(
errorClass="INVALID_LAMBDA_EXPRESSION",
messageParameters={"expression": f},
)

# Delegate parsing and lambda-variable substitution to the JVM
# side (PythonSQLUtils.applyLambda), then wrap the result.
sc = get_active_spark_context()
jvm = cast(JVMView, sc._jvm)
jresult = jvm.PythonSQLUtils.applyLambda(self._jc, f)
return Column(jresult)
return f(self)

def outer(self) -> ParentColumn:
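A quick sketch of the string validation above: everything before the first '->' must be a single Python identifier, otherwise INVALID_LAMBDA_EXPRESSION is raised.

for s in ["x -> x * 2", "x * 2", "123 -> x * 2"]:
    param, sep, _ = s.partition("->")
    print(repr(s), bool(sep) and param.strip().isidentifier())
# 'x -> x * 2' True
# 'x * 2' False
# '123 -> x * 2' False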
39 changes: 31 additions & 8 deletions python/pyspark/sql/column.py
@@ -1512,24 +1512,24 @@ def over(self, window: "WindowSpec") -> "Column":
...

@dispatch_col_method
def transform(self, f: Callable[["Column"], "Column"]) -> "Column":
def transform(self, f: Union[Callable[["Column"], "Column"], str]) -> "Column":
Reviewer comment (Member): Hm, can't people just use it like expr("transform(c, 'x -> x * 2')")? These APIs are supposed to be Python-friendly.
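For comparison, the existing route the reviewer alludes to is roughly this (a sketch, assuming a DataFrame df with a value column): wrap the scalar column in a one-element array, apply SQL's higher-order transform, and take element 0.

from pyspark.sql.functions import expr

df.select(expr("transform(array(value), x -> x * 2)[0]").alias("result"))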

"""
Applies a transformation function to this column.
Applies a transformation to this column.

This method allows you to apply a function that takes a Column and returns a Column,
enabling method chaining and functional transformations.
Accepts either a Python callable or a SQL lambda expression string.

.. versionadded:: 4.1.0

Parameters
----------
f : callable
A function that takes a :class:`Column` and returns a :class:`Column`.
f : callable or str
A function that takes a :class:`Column` and returns a :class:`Column`,
or a SQL lambda expression string in the format ``'param -> expression'``.

Returns
-------
:class:`Column`
The result of applying the function to this column.
The result of applying the transformation to this column.

Examples
--------
@@ -1545,7 +1545,7 @@ def transform(self, f: Callable[["Column"], "Column"]) -> "Column":
| WORLD|
+------+

Example 2: Use lambda functions
Example 2: Use Python lambda functions

>>> df = spark.createDataFrame([(10,), (20,), (30,)], ["value"])
>>> df.select(
@@ -1560,6 +1560,29 @@ def transform(self, f: Callable[["Column"], "Column"]) -> "Column":
| 40|
| 60|
+------+

Example 3: Use SQL lambda expression

>>> df = spark.createDataFrame([(1,), (2,), (3,)], ["value"])
>>> df.select(df.value.transform('x -> x * 2').alias("result")).show()
+------+
|result|
+------+
| 2|
| 4|
| 6|
+------+

Example 4: SQL lambda with function calls

>>> df = spark.createDataFrame([("hello",), ("world",)], ["text"])
>>> df.select(df.text.transform('x -> upper(x)').alias("result")).show()
+------+
|result|
+------+
| HELLO|
| WORLD|
+------+
"""
...

26 changes: 25 additions & 1 deletion python/pyspark/sql/connect/column.py
@@ -44,6 +44,7 @@
LiteralExpression,
CaseWhen,
SortOrder,
SQLExpression,
SubqueryExpression,
CastExpression,
WindowExpression,
@@ -465,7 +466,30 @@ def over(self, window: "WindowSpec") -> ParentColumn:  # type: ignore[override]

return Column(WindowExpression(windowFunction=self._expr, windowSpec=window))

def transform(self, f: Callable[[ParentColumn], ParentColumn]) -> ParentColumn:
def transform(self, f: Union[Callable[[ParentColumn], ParentColumn], str]) -> ParentColumn:
if isinstance(f, str):
arrow_idx = f.find("->")
if arrow_idx == -1:
raise PySparkValueError(
errorClass="INVALID_LAMBDA_EXPRESSION",
messageParameters={"expression": f},
)

param = f[:arrow_idx].strip()
if not param.isidentifier():
raise PySparkValueError(
errorClass="INVALID_LAMBDA_EXPRESSION",
messageParameters={"expression": f},
)

# Build: transform(array(col), lambda)[0]
# The server-side parser handles the lambda expression natively.
lambda_expr = SQLExpression(f)
array_col = Column(UnresolvedFunction("array", [self._expr]))
transform_col = Column(
UnresolvedFunction("transform", [array_col._expr, lambda_expr])
)
return transform_col[0]
return f(self)

def outer(self) -> ParentColumn:
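For reference, the plan built by the Connect branch above corresponds roughly to this SQL (a sketch; the real path goes through Connect protobuf expressions, not a SQL string):

spark.sql("""
    SELECT transform(array(value), x -> x * 2)[0] AS result
    FROM VALUES (1), (2), (3) AS t(value)
""").show()
# +------+
# |result|
# +------+
# |     2|
# |     4|
# |     6|
# +------+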
63 changes: 63 additions & 0 deletions python/pyspark/sql/tests/test_column.py
@@ -502,6 +502,69 @@ def test_transform(self):
self.assertEqual(result[1][0], 40)
self.assertEqual(result[2][0], 60)

def test_transform_sql_lambda_arithmetic(self):
"""Test transform() with SQL lambda expression for arithmetic."""
df = self.spark.createDataFrame([(1,), (2,), (3,)], ["value"])
result = df.select(df.value.transform("x -> x * 2").alias("result")).collect()
self.assertEqual(result[0][0], 2)
self.assertEqual(result[1][0], 4)
self.assertEqual(result[2][0], 6)

def test_transform_sql_lambda_function_call(self):
"""Test transform() with SQL lambda using function calls."""
df = self.spark.createDataFrame([("hello",), ("world",)], ["text"])
result = df.select(df.text.transform("x -> upper(x)").alias("result")).collect()
self.assertEqual(result[0][0], "HELLO")
self.assertEqual(result[1][0], "WORLD")

def test_transform_sql_lambda_conditional(self):
"""Test transform() with SQL lambda using CASE WHEN."""
df = self.spark.createDataFrame([(-1,), (0,), (1,)], ["value"])
result = df.select(
df.value.transform("x -> CASE WHEN x > 0 THEN x ELSE 0 END").alias("result")
).collect()
self.assertEqual(result[0][0], 0)
self.assertEqual(result[1][0], 0)
self.assertEqual(result[2][0], 1)

def test_transform_sql_lambda_with_nulls(self):
"""Test transform() with SQL lambda handles nulls."""
df = self.spark.createDataFrame([(1,), (None,), (3,)], ["value"])
result = df.select(df.value.transform("x -> x * 2").alias("result")).collect()
self.assertEqual(result[0][0], 2)
self.assertIsNone(result[1][0])
self.assertEqual(result[2][0], 6)

def test_transform_sql_lambda_chaining(self):
"""Test chaining SQL lambda transform calls."""
df = self.spark.createDataFrame([(5,)], ["value"])
result = df.select(
df.value.transform("x -> x * 2").transform("y -> y + 1").alias("result")
).collect()
self.assertEqual(result[0][0], 11)

def test_transform_sql_lambda_mixed_chaining(self):
"""Test chaining SQL lambda with Python callable."""
df = self.spark.createDataFrame([(5,)], ["value"])
result = df.select(
df.value.transform("x -> x * 2").transform(lambda c: c + 1).alias("result")
).collect()
self.assertEqual(result[0][0], 11)

def test_transform_sql_lambda_invalid_no_arrow(self):
"""Test transform() raises error for string without ->."""
df = self.spark.createDataFrame([(1,)], ["value"])
with self.assertRaises(PySparkValueError) as ctx:
df.select(df.value.transform("x * 2")).collect()
self.assertIn("INVALID_LAMBDA_EXPRESSION", ctx.exception.getErrorClass())

def test_transform_sql_lambda_invalid_param(self):
"""Test transform() raises error for invalid parameter name."""
df = self.spark.createDataFrame([(1,)], ["value"])
with self.assertRaises(PySparkValueError) as ctx:
df.select(df.value.transform("123 -> x * 2")).collect()
self.assertIn("INVALID_LAMBDA_EXPRESSION", ctx.exception.getErrorClass())


class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
pass
sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.classic.ExpressionUtils
import org.apache.spark.sql.classic.ExpressionUtils.expression
import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -181,6 +182,25 @@ private[sql] object PythonSQLUtils extends Logging {
Column(internal.LambdaFunction(function.node, arguments))
}

def applyLambda(col: Column, lambdaExpr: String): Column = {
val parsed = CatalystSqlParser.parseExpression(lambdaExpr)
parsed match {
case LambdaFunction(function, arguments, _) if arguments.size == 1 =>
val colExpr = expression(col)
val param = arguments.head.asInstanceOf[UnresolvedNamedLambdaVariable]
val replaced = function.transform {
case v: UnresolvedNamedLambdaVariable if v.nameParts == param.nameParts => colExpr
}
ExpressionUtils.column(replaced)
case _: LambdaFunction =>
throw new IllegalArgumentException(
s"Expected a single-parameter lambda expression, but got: $lambdaExpr")
case _ =>
throw new IllegalArgumentException(
s"Expected a lambda expression (param -> expr), but got: $lambdaExpr")
}
}

def namedArgumentExpression(name: String, e: Column): Column =
Column(NamedArgumentExpression(name, expression(e)))

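The substitution applyLambda performs, mirrored in Python on plain strings for illustration only (the real implementation rewrites the parsed Catalyst expression tree, not text):

import re

expr_str = "x -> upper(x)"
param, _, body = (part.strip() for part in expr_str.partition("->"))
print(re.sub(rf"\b{re.escape(param)}\b", "text", body))  # upper(text)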