Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable

import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, ResolveSchemaEvolution, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
Expand All @@ -38,9 +40,10 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType}
import org.apache.spark.sql.types.{ArrayType, AtomicType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -1004,11 +1007,15 @@ case class MergeIntoTable(
case _ => false
}

/**
* Catalog changes for MERGE auto schema evolution, produced from UPDATE/INSERT assignments.
*
* Unlike INSERT evolution (struct diff of table vs query), MERGE uses assignment-driven
* [[TableChange]]s from [[MergeIntoTable.computePendingSchemaChanges]].
*/
override lazy val pendingSchemaChanges: Seq[TableChange] = {
if (schemaEvolutionEnabled && schemaEvolutionReady) {
val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this)
ResolveSchemaEvolution.computeSchemaChanges(
targetTable.schema, referencedSourceSchema, isByName = true).toSeq
MergeIntoTable.computePendingSchemaChanges(this)
} else {
Seq.empty
}
Expand Down Expand Up @@ -1062,52 +1069,168 @@ object MergeIntoTable {
.toSet
}

/**
 * Builds the list of catalog changes for `MERGE ... WITH SCHEMA EVOLUTION` from the explicit
 * `SET` / `VALUES` assignments in `WHEN MATCHED` (UPDATE) and `WHEN NOT MATCHED` (INSERT)
 * clauses.
 *
 * `UPDATE *` and `INSERT *` return no changes; those branches need to be expanded by other
 * analysis steps before this logic applies.
 *
 * Only assignments that copy from a source column (or nested field) into the same path on the
 * target are considered, including new target columns that do not exist on the table yet but
 * name the same path as that source field.
 *
 * From those assignments we may produce:
 *  - `addColumn` when the assignment targets a new column and the table does not already have
 *    it at that name/path.
 *  - `updateColumnType` when an existing target column and the matching source column disagree
 *    on a simple (non-struct) type (for example widening `INT` to `BIGINT`).
 *  - Extra nested `addColumn` steps when the source side has struct fields (including inside
 *    arrays or maps) that the target table row does not yet store at the same path.
 *  - Nothing extra when the types already line up for that assignment.
 *
 * @param merge analyzed MERGE command (must satisfy `schemaEvolutionEnabled` and
 *              `schemaEvolutionReady` on the caller side)
 * @return catalog edits to apply to the target table, deduplicated and ordered by assignment
 *         then stable set iteration
 */
private def computePendingSchemaChanges(merge: MergeIntoTable): Seq[TableChange] = {
  val actions = merge.matchedActions ++ merge.notMatchedActions
  val originalTarget = merge.targetTable.schema
  val originalSource = merge.sourceTable.schema

  // Collect only explicit assignments; star actions contribute nothing here because they are
  // expected to be expanded into explicit assignments by earlier analysis steps.
  val schemaEvolutionAssignments = actions.flatMap {
    case a: UpdateAction => a.assignments
    case a: InsertAction => a.assignments
    case _: UpdateStarAction | _: InsertStarAction => Seq.empty
    case _ => Seq.empty
  }.filter(isSchemaEvolutionCandidate(_, merge.sourceTable))

  // LinkedHashSet deduplicates repeated changes while keeping insertion order stable.
  val changes = mutable.LinkedHashSet.empty[TableChange]
  // NOTE(review): passes null as the cause, so the error only shows the full target/source
  // schemas. Consider surfacing the conflicting field path for easier debugging — TODO confirm
  // whether the error class supports it.
  val failIncompatible: () => Nothing = () =>
    throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
      originalTarget, originalSource, null)

  schemaEvolutionAssignments.foreach {
    case a if !a.key.resolved =>
      // Unresolved key: the assignment names a column that does not exist on the target yet.
      val fieldPath = extractFieldPath(a.key, allowUnresolved = true)
      if (fieldPath.nonEmpty &&
          !SchemaUtils.fieldExistsAtPath(originalTarget, fieldPath)) {
        changes += TableChange.addColumn(fieldPath.toArray, a.value.dataType.asNullable)
      }
    case a if a.key.dataType != a.value.dataType =>
      // Resolved key whose type disagrees with the assigned value: recurse through the types
      // to emit updateColumnType / nested addColumn changes.
      computeTypeSchemaChanges(
        a.key.dataType,
        a.value.dataType,
        changes,
        fieldPath = extractFieldPath(a.key, allowUnresolved = false),
        targetTypeAtPath = originalTarget,
        failIncompatible)
    case _ =>
  }

  changes.toSeq
}

/**
 * Recursively compares assignment key vs value types at `fieldPath` and appends matching
 * `addColumn` / `updateColumnType` entries to `changes`.
 *
 * `keyType` and `valueType` come from the assignment expression types (MERGE target side vs
 * source side). `targetTypeAtPath` is the type of the same path in the current
 * MERGE target table.
 *
 * @param keyType type of the assignment key at this path (MERGE target column expression)
 * @param valueType type of the assignment value at this path (typically source column)
 * @param changes accumulator for [[TableChange]] instances
 * @param fieldPath qualified path segments for nested columns (`element` / `key` / `value`
 *                  under arrays and maps)
 * @param targetTypeAtPath type of the loaded MERGE target table at `fieldPath`
 * @param failIncompatible error handling when assignment types cannot be reconciled
 */
private def computeTypeSchemaChanges(
    keyType: DataType,
    valueType: DataType,
    changes: mutable.LinkedHashSet[TableChange],
    fieldPath: Seq[String],
    targetTypeAtPath: DataType,
    failIncompatible: () => Nothing): Unit = {
  (keyType, valueType) match {
    case (StructType(keyFields), StructType(valueFields)) =>
      val keyFieldMap = SchemaUtils.toFieldMap(keyFields)
      val valueFieldMap = SchemaUtils.toFieldMap(valueFields)
      val targetFieldMap = targetTypeAtPath match {
        case st: StructType => SchemaUtils.toFieldMap(st.fields)
        case _ => Map.empty[String, StructField]
      }

      // Fields present on both sides: recurse to reconcile their types.
      keyFields
        .filter(f => valueFieldMap.contains(f.name))
        .foreach { f =>
          // Prefer the table's type at this path; fall back to the expression type when the
          // table does not yet store the field.
          val nextTargetType =
            targetFieldMap.get(f.name).map(_.dataType).getOrElse(f.dataType)
          computeTypeSchemaChanges(
            f.dataType,
            valueFieldMap(f.name).dataType,
            changes,
            fieldPath ++ Seq(f.name),
            nextTargetType,
            failIncompatible)
        }

      // Source-only fields: add them (nullable) unless the table already has them.
      valueFields
        .filterNot(f => keyFieldMap.contains(f.name))
        .foreach { f =>
          if (!targetFieldMap.contains(f.name)) {
            changes += TableChange.addColumn(
              (fieldPath :+ f.name).toArray,
              f.dataType.asNullable)
          }
        }

    case (ArrayType(keyElemType, _), ArrayType(valueElemType, _)) =>
      val nextTargetType = targetTypeAtPath match {
        case ArrayType(elementType, _) => elementType
        case _ => keyElemType
      }
      computeTypeSchemaChanges(
        keyElemType,
        valueElemType,
        changes,
        fieldPath :+ "element",
        nextTargetType,
        failIncompatible)

    case (MapType(keySideMapKeyType, keySideMapValueType, _),
        MapType(valueSideMapKeyType, valueSideMapValueType, _)) =>
      val (nextMapKeyTargetType, nextMapValueTargetType) = targetTypeAtPath match {
        case MapType(kt, vt, _) => (kt, vt)
        case _ => (keySideMapKeyType, keySideMapValueType)
      }
      computeTypeSchemaChanges(
        keySideMapKeyType,
        valueSideMapKeyType,
        changes,
        fieldPath :+ "key",
        nextMapKeyTargetType,
        failIncompatible)
      computeTypeSchemaChanges(
        keySideMapValueType,
        valueSideMapValueType,
        changes,
        fieldPath :+ "value",
        nextMapValueTargetType,
        failIncompatible)

    // Differing atomic types: record a column type update to the source-side type.
    case (kt: AtomicType, vt: AtomicType) if kt != vt =>
      changes += TableChange.updateColumnType(fieldPath.toArray, vt)

    // Identical types: nothing to evolve at this path.
    case (kt, vt) if kt == vt =>
      ()

    case _ =>
      failIncompatible()
  }
}

// Helper method to extract field path from an Expression.
Expand All @@ -1121,24 +1244,6 @@ object MergeIntoTable {
}
}

// Helper method to check if a given field path is a prefix of another path.
private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
prefix.length <= path.length && prefix.zip(path).forall {
case (prefixNamePart, pathNamePart) =>
SQLConf.get.resolver(prefixNamePart, pathNamePart)
}

// Helper method to check if an assignment key is equal to a source column
// and if the assignment value is that same source column.
// Example: UPDATE SET target.a = source.a
private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = {
// key must be a non-qualified field path that may be added to target schema via evolution
val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true)
// value should always be resolved (from source)
val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false)
assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath
}

private def areSchemaEvolutionReady(
assignments: Seq[Assignment],
source: LogicalPlan): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES}
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -125,6 +127,70 @@ private[spark] object SchemaUtils {
}
}

/**
 * Returns whether a column path exists in the schema's nested structure.
 *
 * An empty path never matches; otherwise traversal is delegated to the recursive helper.
 *
 * @param root type schema
 * @param path name segments
 */
def fieldExistsAtPath(
    root: StructType,
    path: Seq[String]): Boolean = {
  // Guard the empty path up front, then walk the schema recursively.
  path.nonEmpty && fieldExistsAtPathInternal(root, path)
}

/**
 * Recursively checks whether `parts` names an existing path under `dt`.
 *
 * Struct segments are resolved by field name (honoring case sensitivity via `toFieldMap`);
 * the synthetic segments `element`, `key`, and `value` descend into array elements and map
 * keys/values respectively.
 *
 * @param dt current type being inspected
 * @param parts remaining path segments to resolve
 */
private def fieldExistsAtPathInternal(
    dt: DataType,
    parts: Seq[String]): Boolean = {
  parts match {
    // Base case handled once: all segments consumed means the path exists.
    case Seq() => true
    case head +: tail =>
      dt match {
        case st: StructType =>
          toFieldMap(st.fields).get(head)
            .exists(f => fieldExistsAtPathInternal(f.dataType, tail))
        case ArrayType(elementType, _) if head == "element" =>
          fieldExistsAtPathInternal(elementType, tail)
        case MapType(keyType, _, _) if head == "key" =>
          fieldExistsAtPathInternal(keyType, tail)
        case MapType(_, valueType, _) if head == "value" =>
          fieldExistsAtPathInternal(valueType, tail)
        // Remaining segments but no matching nested structure: path does not exist.
        case _ => false
      }
  }
}

/**
 * Builds a name-to-field lookup for the given struct fields, honoring the session's
 * case sensitivity setting.
 *
 * @param fields the fields to create the map for
 * @return a map of field name to StructField; case-insensitive keys unless
 *         case-sensitive analysis is enabled
 */
def toFieldMap(
    fields: Array[StructField]): Map[String, StructField] = {
  val byName = fields.iterator.map(f => f.name -> f).toMap
  // Wrap in a case-insensitive view only when the session resolves names insensitively.
  if (SQLConf.get.caseSensitiveAnalysis) byName else CaseInsensitiveMap(byName)
}

/**
* Checks if input column names have duplicate identifiers. This throws an exception if
* the duplication exists.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ import java.util.Locale

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StringType, StructType}

class SchemaUtilsSuite extends SparkFunSuite {
class SchemaUtilsSuite extends SparkFunSuite with SQLConfHelper {

private def resolver(caseSensitiveAnalysis: Boolean): Resolver = {
if (caseSensitiveAnalysis) {
Expand Down Expand Up @@ -110,4 +112,27 @@ class SchemaUtilsSuite extends SparkFunSuite {
parameters = Map("columnName" -> "`camelcase`"))
}
}

test("fieldExistsAtPath: structs, arrays, maps, and name case rules") {
  val inner = new StructType().add("y", LongType)
  val schema = new StructType()
    .add("a", LongType)
    .add("S", inner)
    .add("arr", ArrayType(LongType))
    .add("m", MapType(StringType, LongType))

  // An empty path never matches.
  assert(!SchemaUtils.fieldExistsAtPath(schema, Seq.empty))
  // Top-level column lookup.
  assert(SchemaUtils.fieldExistsAtPath(schema, Seq("a")))
  // Name matching follows the session's case-sensitivity setting.
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
    assert(!SchemaUtils.fieldExistsAtPath(schema, Seq("A")))
  }
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
    assert(SchemaUtils.fieldExistsAtPath(schema, Seq("A")))
  }
  // Nested struct field, array element, and map key/value traversal.
  assert(SchemaUtils.fieldExistsAtPath(schema, Seq("S", "y")))
  assert(SchemaUtils.fieldExistsAtPath(schema, Seq("arr", "element")))
  assert(SchemaUtils.fieldExistsAtPath(schema, Seq("m", "key")))
  assert(SchemaUtils.fieldExistsAtPath(schema, Seq("m", "value")))
  // Unknown top-level name.
  assert(!SchemaUtils.fieldExistsAtPath(schema, Seq("missing")))
}
}
Loading