Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ case class CollectList(
s"$prettyName($child)$ignoreNullsStr"
}

override def sql(isDistinct: Boolean): String = {
  // Render as SQL text, e.g. `collect_list(DISTINCT col) RESPECT NULLS`.
  // RESPECT NULLS is the non-default mode, so it is the only one spelled out.
  val distinctPrefix = if (isDistinct) "DISTINCT " else ""
  val respectNullsSuffix = if (ignoreNulls) "" else " RESPECT NULLS"
  s"$prettyName($distinctPrefix${child.sql})$respectNullsSuffix"
}

override protected def withNewChildInternal(newChild: Expression): CollectList = {
  // Rebuild this aggregate around the replacement child expression.
  this.copy(child = newChild)
}
}
Expand Down Expand Up @@ -268,6 +274,12 @@ case class CollectSet(
s"$prettyName($child)$ignoreNullsStr"
}

override def sql(isDistinct: Boolean): String = {
  // Render as SQL text, e.g. `collect_set(DISTINCT col) RESPECT NULLS`.
  // NOTE(review): isDistinct should always be false for CollectSet because
  // DISTINCT is removed by EliminateDistinct in Optimizer.scala; the flag is
  // still honored here for consistency with the other aggregate functions.
  val distinct = if (isDistinct) "DISTINCT " else ""
  // RESPECT NULLS is the non-default mode, so it is the only one spelled out.
  val nullsStr = if (ignoreNulls) "" else " RESPECT NULLS"
  s"$prettyName($distinct${child.sql})$nullsStr"
}

override protected def withNewChildInternal(newChild: Expression): CollectSet = {
  // Rebuild this aggregate around the replacement child expression.
  this.copy(child = newChild)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,19 @@ class DataFrameAggregateSuite extends QueryTest
Seq(Row(Seq(1.0, 2.0))))
}

test("SPARK-56155: collect functions sql() display RESPECT NULLS") {
  // Column names come from Expression.sql(), so the RESPECT NULLS modifier
  // must round-trip through parsing and back into the generated column name,
  // while the default (ignore-nulls) form stays unadorned.
  val input = Seq((1, Some(2)), (1, None), (1, Some(4))).toDF("a", "b")

  val listRespectNulls = input.selectExpr("collect_list(b) RESPECT NULLS")
  assert(listRespectNulls.columns.head == "collect_list(b) RESPECT NULLS")
  val listDefault = input.selectExpr("collect_list(b)")
  assert(listDefault.columns.head == "collect_list(b)")

  val setRespectNulls = input.selectExpr("collect_set(b) RESPECT NULLS")
  assert(setRespectNulls.columns.head == "collect_set(b) RESPECT NULLS")
  val setDefault = input.selectExpr("collect_set(b)")
  assert(setDefault.columns.head == "collect_set(b)")
}

test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
Expand Down