From 4ba9dbac01b3697694b70f27d39136b65b105086 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 20 Mar 2026 17:49:45 -0700 Subject: [PATCH 1/2] [SPARK-56125][SQL] Simplify schema calculation code for Merge Into Schema Evolution --- .../catalyst/plans/logical/v2Commands.scala | 241 +++++++++++++----- .../apache/spark/sql/util/SchemaUtils.scala | 78 ++++++ .../spark/sql/util/SchemaUtilsSuite.scala | 21 +- ...ergeIntoSchemaEvolutionExtraSQLTests.scala | 2 +- 4 files changed, 279 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index d6beacadbb674..75eeead14fb39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable + import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, ResolveSchemaEvolution, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} @@ -38,9 +40,11 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
ExtractV2Table} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, AtomicType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -1004,11 +1008,15 @@ case class MergeIntoTable( case _ => false } + /** + * Catalog changes for MERGE auto schema evolution, produced from UPDATE/INSERT assignments. + * + * Unlike INSERT evolution (struct diff of table vs query), MERGE uses assignment-driven + * [[TableChange]]s from [[MergeIntoTable.computePendingSchemaChanges]]. + */ override lazy val pendingSchemaChanges: Seq[TableChange] = { if (schemaEvolutionEnabled && schemaEvolutionReady) { - val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this) - ResolveSchemaEvolution.computeSchemaChanges( - targetTable.schema, referencedSourceSchema, isByName = true).toSeq + MergeIntoTable.computePendingSchemaChanges(this) } else { Seq.empty } @@ -1062,52 +1070,181 @@ object MergeIntoTable { .toSet } - // A pruned version of source schema that only contains columns/nested fields - // explicitly and directly assigned to a target counterpart in MERGE INTO actions, - // which are relevant for schema evolution. - // Examples: - // * UPDATE SET target.a = source.a - // * UPDATE SET nested.a = source.nested.a - // * INSERT (a, nested.b) VALUES (source.a, source.nested.b) - // New columns/nested fields in this schema that are not existing in target schema - // will be added for schema evolution. - def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { + /** + * Builds the list of catalog changes for `MERGE ... 
WITH SCHEMA EVOLUTION` from the explicit + * `SET` / `VALUES` assignments in `WHEN MATCHED` (UPDATE) and `WHEN NOT MATCHED` (INSERT) + * clauses. + * + * `UPDATE *` and `INSERT *` return no changes; those branches need to be expanded by other + * analysis steps before this logic applies. + * + * Only assignments that copy from a source column (or nested field) into the same path on the + * target are considered—including new target columns that do not exist on the table yet but name + * the same path as that source field. + * + * From those assignments we may produce: + * - `addColumn` when the assignment targets a new column and the table does not already have it + * at that name/path. + * - `updateColumnType` when an existing target column and the matching source column disagree on + * a simple (non-struct) type—for example widening `INT` to `BIGINT`. + * - Extra nested `addColumn` steps when the source side has struct fields (including inside + * arrays or maps) that the target table row does not yet store at the same path. + * - Nothing extra when the types already line up for that assignment. 
+ * + * @param merge analyzed MERGE command (must satisfy `schemaEvolutionEnabled` and + * `schemaEvolutionReady` on the caller side) + * @return catalog edits to apply to the target table, deduplicated and ordered by assignment + * then stable set iteration + */ + private def computePendingSchemaChanges(merge: MergeIntoTable): Seq[TableChange] = { val actions = merge.matchedActions ++ merge.notMatchedActions - val assignments = actions.collect { + val originalTarget = merge.targetTable.schema + val originalSource = merge.sourceTable.schema + + val schemaEvolutionAssignments = actions.flatMap { case a: UpdateAction => a.assignments case a: InsertAction => a.assignments - }.flatten - - val containsStarAction = actions.exists { - case _: UpdateStarAction => true - case _: InsertStarAction => true - case _ => false + case _: UpdateStarAction | _: InsertStarAction => Seq.empty + case _ => Seq.empty + }.filter(isSchemaEvolutionCandidate(_, merge.sourceTable)) + + val changes = mutable.LinkedHashSet.empty[TableChange] + val failIncompatible: () => Nothing = () => + throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( + originalTarget, originalSource, null) + + schemaEvolutionAssignments.foreach { + case a if !a.key.resolved => + val fieldPath = extractFieldPath(a.key, allowUnresolved = true) + if (fieldPath.nonEmpty && + !SchemaUtils.fieldExistsAtPath(originalTarget, fieldPath, SQLConf.get.resolver)) { + changes += TableChange.addColumn(fieldPath.toArray, a.value.dataType.asNullable) + } + case a if a.key.dataType != a.value.dataType => + computeTypeSchemaChanges( + a.key.dataType, + a.value.dataType, + changes, + fieldPath = extractFieldPath(a.key, allowUnresolved = false), + targetTypeAtPath = originalTarget, + failIncompatible) + case _ => } - def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType = - StructType(sourceSchema.flatMap { field => - val fieldPath = basePath :+ field.name - - field.dataType match { - // Specifically 
assigned to in one clause: - // always keep, including all nested attributes - case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field) - // If this is a struct and one of the children is being assigned to in a merge clause, - // keep it and continue filtering children. - case struct: StructType if assignments.exists(assign => - isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - // The field isn't assigned to directly or indirectly (i.e. its children) in any non-* - // clause. Check if it should be kept with any * action. - case struct: StructType if containsStarAction => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - case _ if containsStarAction => Some(field) - // The field and its children are not assigned to in any * or non-* action, drop it. - case _ => None + changes.toSeq + } + + /** + * Recursively compares assignment key vs value types at `fieldPath` and appends matching + * `addColumn` / `updateColumnType` entries to `changes`. + * + * `keyType` and `valueType` come from the assignment expression types (MERGE target side vs + * source side). `targetTypeAtPath` is the type of the same path in `merge.targetTable.schema` + * (catalog-backed). Those can differ after `alterTable` + reload: the table may already include + * a new field while the key expression still carries an older [[DataType]], so nested + * `addColumn` is skipped when the table struct already contains the field name. 
+ * + * @param keyType type of the assignment key at this path (MERGE target column expression) + * @param valueType type of the assignment value at this path (typically source column) + * @param changes accumulator for [[TableChange]] instances + * @param fieldPath qualified path segments for nested columns (`element` / `key` / `value` + * under arrays and maps, per DSv2 [[TableChange]] conventions) + * @param targetTypeAtPath type of the loaded target table at `fieldPath` (not necessarily + * equal to `keyType` when catalog and analyzer are out of sync) + * @param failIncompatible called when assignment types cannot be reconciled; throws with MERGE + * target and source schemas for the error message + */ + private def computeTypeSchemaChanges( + keyType: DataType, + valueType: DataType, + changes: mutable.LinkedHashSet[TableChange], + fieldPath: Seq[String], + targetTypeAtPath: DataType, + failIncompatible: () => Nothing): Unit = { + (keyType, valueType) match { + case (StructType(keyFields), StructType(valueFields)) => + val keyFieldMap = toFieldMap(keyFields) + val valueFieldMap = toFieldMap(valueFields) + val targetFieldMap = targetTypeAtPath match { + case st: StructType => toFieldMap(st.fields) + case _ => Map.empty[String, StructField] + } + + keyFields + .filter(f => valueFieldMap.contains(f.name)) + .foreach { f => + val nextTargetType = + targetFieldMap.get(f.name).map(_.dataType).getOrElse(f.dataType) + computeTypeSchemaChanges( + f.dataType, + valueFieldMap(f.name).dataType, + changes, + fieldPath ++ Seq(f.name), + nextTargetType, + failIncompatible) + } + + valueFields + .filterNot(f => keyFieldMap.contains(f.name)) + .foreach { f => + if (!targetFieldMap.contains(f.name)) { + changes += TableChange.addColumn( + (fieldPath :+ f.name).toArray, + f.dataType.asNullable) + } + } + + case (ArrayType(keyElemType, _), ArrayType(valueElemType, _)) => + val nextTargetType = targetTypeAtPath match { + case ArrayType(elementType, _) => elementType + case _ => 
keyElemType + } + computeTypeSchemaChanges( + keyElemType, + valueElemType, + changes, + fieldPath :+ "element", + nextTargetType, + failIncompatible) + + case (MapType(keySideMapKeyType, keySideMapValueType, _), + MapType(valueSideMapKeyType, valueSideMapValueType, _)) => + val (nextMapKeyTargetType, nextMapValueTargetType) = targetTypeAtPath match { + case MapType(kt, vt, _) => (kt, vt) + case _ => (keySideMapKeyType, keySideMapValueType) } - }) + computeTypeSchemaChanges( + keySideMapKeyType, + valueSideMapKeyType, + changes, + fieldPath :+ "key", + nextMapKeyTargetType, + failIncompatible) + computeTypeSchemaChanges( + keySideMapValueType, + valueSideMapValueType, + changes, + fieldPath :+ "value", + nextMapValueTargetType, + failIncompatible) + + case (kt: AtomicType, vt: AtomicType) if kt != vt => + changes += TableChange.updateColumnType(fieldPath.toArray, vt) + + case (kt, vt) if kt == vt => + + case _ => + failIncompatible() + } + } - filterSchema(merge.sourceTable.schema, Seq.empty) + private def toFieldMap(fields: Array[StructField]): Map[String, StructField] = { + val fieldMap = fields.map(f => f.name -> f).toMap + if (SQLConf.get.caseSensitiveAnalysis) { + fieldMap + } else { + CaseInsensitiveMap(fieldMap) + } } // Helper method to extract field path from an Expression. @@ -1121,24 +1258,6 @@ object MergeIntoTable { } } - // Helper method to check if a given field path is a prefix of another path. - private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = - prefix.length <= path.length && prefix.zip(path).forall { - case (prefixNamePart, pathNamePart) => - SQLConf.get.resolver(prefixNamePart, pathNamePart) - } - - // Helper method to check if an assignment key is equal to a source column - // and if the assignment value is that same source column. 
- // Example: UPDATE SET target.a = source.a - private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = { - // key must be a non-qualified field path that may be added to target schema via evolution - val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true) - // value should always be resolved (from source) - val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false) - assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath - } - private def areSchemaEvolutionReady( assignments: Seq[Assignment], source: LogicalPlan): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 58ababa04739f..3c9c64d356201 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -125,6 +126,83 @@ private[spark] object SchemaUtils { } } + /** + * Returns whether a column path exists in the schema's nested structure. + * + * Walks [[StructType]] fields and DSv2 nested path segments (`element`, `key`, `value`) + * through [[ArrayType]] / [[MapType]], consistent with + * [[org.apache.spark.sql.connector.catalog.TableChange]] path conventions. 
+ * Struct field name matching uses the same case rules as + * [[checkSchemaColumnNameDuplication]]. + * + * @param root top-level row schema (e.g. a table schema) + * @param path name segments (e.g. from an unresolved attribute or path extractor) + * @param resolver resolver used to compare struct field names + */ + def fieldExistsAtPath(root: StructType, path: Seq[String], resolver: Resolver): Boolean = { + fieldExistsAtPath(root, path, isCaseSensitiveAnalysis(resolver)) + } + + /** + * Returns whether a column path exists in the schema's nested structure. + * + * @param root top-level row schema (e.g. a table schema) + * @param path name segments (e.g. from an unresolved attribute or path extractor) + * @param caseSensitiveAnalysis whether struct field name matching is case sensitive + */ + def fieldExistsAtPath( + root: StructType, + path: Seq[String], + caseSensitiveAnalysis: Boolean): Boolean = { + fieldExistsAtPathInternal(root, path, caseSensitiveAnalysis) + } + + private def fieldExistsAtPathInternal( + dt: DataType, + parts: Seq[String], + caseSensitiveAnalysis: Boolean): Boolean = { + def checkAndRecurse( + nextType: DataType, + remaining: Seq[String], + caseSensitiveAnalysis: Boolean): Boolean = { + if (remaining.isEmpty) { + true + } else { + fieldExistsAtPathInternal(nextType, remaining, caseSensitiveAnalysis) + } + } + + if (parts.isEmpty) { + true + } else { + dt match { + case st: StructType => + toFieldMap(st.fields, caseSensitiveAnalysis).get(parts.head) match { + case Some(f) => checkAndRecurse(f.dataType, parts.tail, caseSensitiveAnalysis) + case None => false + } + case ArrayType(elementType, _) if parts.head == "element" => + checkAndRecurse(elementType, parts.tail, caseSensitiveAnalysis) + case MapType(keyType, _, _) if parts.head == "key" => + checkAndRecurse(keyType, parts.tail, caseSensitiveAnalysis) + case MapType(_, valueType, _) if parts.head == "value" => + checkAndRecurse(valueType, parts.tail, caseSensitiveAnalysis) + case _ => 
false + } + } + } + + private def toFieldMap( + fields: Array[StructField], + caseSensitiveAnalysis: Boolean): Map[String, StructField] = { + val fieldMap = fields.map(f => f.name -> f).toMap + if (caseSensitiveAnalysis) { + fieldMap + } else { + CaseInsensitiveMap(fieldMap) + } + } + /** * Checks if input column names have duplicate identifiers. This throws an exception if * the duplication exists. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala index a277bb021c3f6..a9d71e736a679 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType} +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StringType, StructType} class SchemaUtilsSuite extends SparkFunSuite { @@ -110,4 +110,23 @@ class SchemaUtilsSuite extends SparkFunSuite { parameters = Map("columnName" -> "`camelcase`")) } } + + test("fieldExistsAtPath: structs, arrays, maps, and name case rules") { + val nested = new StructType().add("y", LongType) + val root = new StructType() + .add("a", LongType) + .add("S", nested) + .add("arr", ArrayType(LongType)) + .add("m", MapType(StringType, LongType)) + + assert(!SchemaUtils.fieldExistsAtPath(root, Seq.empty, resolver(true))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("a"), resolver(true))) + assert(!SchemaUtils.fieldExistsAtPath(root, Seq("A"), resolver(true))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("A"), resolver(false))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("S", "y"), resolver(true))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("arr", 
"element"), resolver(true))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "key"), resolver(true))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "value"), resolver(true))) + assert(!SchemaUtils.fieldExistsAtPath(root, Seq("missing"), resolver(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala index 8565c0b31c0c1..f4308162277c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala @@ -142,7 +142,7 @@ trait MergeIntoSchemaEvolutionExtraSQLTests extends RowLevelOperationSuiteBase { s"Error message should mention table name: ${ex.getMessage}") val msg = ex.getMessage - val expectedChanges = "ALTER COLUMN pk TYPE BIGINT; ADD COLUMN active BOOLEAN" + val expectedChanges = "ADD COLUMN active BOOLEAN; ALTER COLUMN pk TYPE BIGINT" assert(msg.contains(expectedChanges), s"Error message should contain exact changes '$expectedChanges': $msg") } From 58e32836c030129b23b813f0619e3b8662682a39 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 20 Mar 2026 18:44:45 -0700 Subject: [PATCH 2/2] Some more cleanup --- .../catalyst/plans/logical/v2Commands.scala | 36 ++++------- .../apache/spark/sql/util/SchemaUtils.scala | 62 ++++++++----------- .../spark/sql/util/SchemaUtilsSuite.scala | 26 +++++--- 3 files changed, 52 insertions(+), 72 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 75eeead14fb39..1b8afc1511081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, M import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, AtomicType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.ArrayImplicits._ @@ -1079,14 +1078,14 @@ object MergeIntoTable { * analysis steps before this logic applies. * * Only assignments that copy from a source column (or nested field) into the same path on the - * target are considered—including new target columns that do not exist on the table yet but name + * target are considered, including new target columns that do not exist on the table yet but name * the same path as that source field. * * From those assignments we may produce: * - `addColumn` when the assignment targets a new column and the table does not already have it * at that name/path. * - `updateColumnType` when an existing target column and the matching source column disagree on - * a simple (non-struct) type—for example widening `INT` to `BIGINT`. + * a simple (non-struct) type (for example widening `INT` to `BIGINT`). * - Extra nested `addColumn` steps when the source side has struct fields (including inside * arrays or maps) that the target table row does not yet store at the same path. * - Nothing extra when the types already line up for that assignment. 
@@ -1117,7 +1116,7 @@ object MergeIntoTable { case a if !a.key.resolved => val fieldPath = extractFieldPath(a.key, allowUnresolved = true) if (fieldPath.nonEmpty && - !SchemaUtils.fieldExistsAtPath(originalTarget, fieldPath, SQLConf.get.resolver)) { + !SchemaUtils.fieldExistsAtPath(originalTarget, fieldPath)) { changes += TableChange.addColumn(fieldPath.toArray, a.value.dataType.asNullable) } case a if a.key.dataType != a.value.dataType => @@ -1139,20 +1138,16 @@ object MergeIntoTable { * `addColumn` / `updateColumnType` entries to `changes`. * * `keyType` and `valueType` come from the assignment expression types (MERGE target side vs - * source side). `targetTypeAtPath` is the type of the same path in `merge.targetTable.schema` - * (catalog-backed). Those can differ after `alterTable` + reload: the table may already include - * a new field while the key expression still carries an older [[DataType]], so nested - * `addColumn` is skipped when the table struct already contains the field name. + * source side). `targetTypeAtPath` is the type of the same path in the current + * MERGE target table. 
 * @param keyType type of the assignment key at this path (MERGE target column expression) * @param valueType type of the assignment value at this path (typically source column) * @param changes accumulator for [[TableChange]] instances * @param fieldPath qualified path segments for nested columns (`element` / `key` / `value` - * under arrays and maps, per DSv2 [[TableChange]] conventions) - * @param targetTypeAtPath type of the loaded target table at `fieldPath` (not necessarily - * equal to `keyType` when catalog and analyzer are out of sync) - * @param failIncompatible called when assignment types cannot be reconciled; throws with MERGE - * target and source schemas for the error message + * under arrays and maps) + * @param targetTypeAtPath type of the loaded MERGE target table at `fieldPath` + * @param failIncompatible error handling when assignment types cannot be reconciled */ private def computeTypeSchemaChanges( keyType: DataType, @@ -1163,10 +1158,10 @@ object MergeIntoTable { failIncompatible: () => Nothing): Unit = { (keyType, valueType) match { case (StructType(keyFields), StructType(valueFields)) => - val keyFieldMap = toFieldMap(keyFields) - val valueFieldMap = toFieldMap(valueFields) + val keyFieldMap = SchemaUtils.toFieldMap(keyFields) + val valueFieldMap = SchemaUtils.toFieldMap(valueFields) val targetFieldMap = targetTypeAtPath match { - case st: StructType => toFieldMap(st.fields) + case st: StructType => SchemaUtils.toFieldMap(st.fields) case _ => Map.empty[String, StructField] } @@ -1238,15 +1233,6 @@ object MergeIntoTable { } } - private def toFieldMap(fields: Array[StructField]): Map[String, StructField] = { - val fieldMap = fields.map(f => f.name -> f).toMap - if (SQLConf.get.caseSensitiveAnalysis) { - fieldMap - } else { - CaseInsensitiveMap(fieldMap) - } - } - // Helper method to extract field path from an Expression. 
private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = { expr match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 3c9c64d356201..385a1534d4cc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ @@ -129,46 +130,29 @@ private[spark] object SchemaUtils { /** * Returns whether a column path exists in the schema's nested structure. * - * Walks [[StructType]] fields and DSv2 nested path segments (`element`, `key`, `value`) - * through [[ArrayType]] / [[MapType]], consistent with - * [[org.apache.spark.sql.connector.catalog.TableChange]] path conventions. - * Struct field name matching uses the same case rules as - * [[checkSchemaColumnNameDuplication]]. - * - * @param root top-level row schema (e.g. a table schema) - * @param path name segments (e.g. from an unresolved attribute or path extractor) - * @param resolver resolver used to compare struct field names - */ - def fieldExistsAtPath(root: StructType, path: Seq[String], resolver: Resolver): Boolean = { - fieldExistsAtPath(root, path, isCaseSensitiveAnalysis(resolver)) - } - - /** - * Returns whether a column path exists in the schema's nested structure. 
- * - * @param root top-level row schema (e.g. a table schema) - * @param path name segments (e.g. from an unresolved attribute or path extractor) - * @param caseSensitiveAnalysis whether struct field name matching is case sensitive + * @param root type schema + * @param path name segments */ def fieldExistsAtPath( root: StructType, - path: Seq[String], - caseSensitiveAnalysis: Boolean): Boolean = { - fieldExistsAtPathInternal(root, path, caseSensitiveAnalysis) + path: Seq[String]): Boolean = { + if (path.isEmpty) { + false + } else { + fieldExistsAtPathInternal(root, path) + } } private def fieldExistsAtPathInternal( dt: DataType, - parts: Seq[String], - caseSensitiveAnalysis: Boolean): Boolean = { + parts: Seq[String]): Boolean = { def checkAndRecurse( nextType: DataType, - remaining: Seq[String], - caseSensitiveAnalysis: Boolean): Boolean = { + remaining: Seq[String]): Boolean = { if (remaining.isEmpty) { true } else { - fieldExistsAtPathInternal(nextType, remaining, caseSensitiveAnalysis) + fieldExistsAtPathInternal(nextType, remaining) } } @@ -177,26 +161,30 @@ private[spark] object SchemaUtils { } else { dt match { case st: StructType => - toFieldMap(st.fields, caseSensitiveAnalysis).get(parts.head) match { - case Some(f) => checkAndRecurse(f.dataType, parts.tail, caseSensitiveAnalysis) + toFieldMap(st.fields).get(parts.head) match { + case Some(f) => checkAndRecurse(f.dataType, parts.tail) case None => false } case ArrayType(elementType, _) if parts.head == "element" => - checkAndRecurse(elementType, parts.tail, caseSensitiveAnalysis) + checkAndRecurse(elementType, parts.tail) case MapType(keyType, _, _) if parts.head == "key" => - checkAndRecurse(keyType, parts.tail, caseSensitiveAnalysis) + checkAndRecurse(keyType, parts.tail) case MapType(_, valueType, _) if parts.head == "value" => - checkAndRecurse(valueType, parts.tail, caseSensitiveAnalysis) + checkAndRecurse(valueType, parts.tail) case _ => false } } } - private def toFieldMap( - fields: 
Array[StructField], - caseSensitiveAnalysis: Boolean): Map[String, StructField] = { + /** + * Returns a map of field name to StructField for the given fields. + * @param fields the fields to create the map for + * @return a map of field name to StructField for the given fields + */ + def toFieldMap( + fields: Array[StructField]): Map[String, StructField] = { val fieldMap = fields.map(f => f.name -> f).toMap - if (caseSensitiveAnalysis) { + if (SQLConf.get.caseSensitiveAnalysis) { fieldMap } else { CaseInsensitiveMap(fieldMap) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala index a9d71e736a679..91b4edb618a6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -21,10 +21,12 @@ import java.util.Locale import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StringType, StructType} -class SchemaUtilsSuite extends SparkFunSuite { +class SchemaUtilsSuite extends SparkFunSuite with SQLConfHelper { private def resolver(caseSensitiveAnalysis: Boolean): Resolver = { if (caseSensitiveAnalysis) { @@ -119,14 +121,18 @@ class SchemaUtilsSuite extends SparkFunSuite { .add("arr", ArrayType(LongType)) .add("m", MapType(StringType, LongType)) - assert(!SchemaUtils.fieldExistsAtPath(root, Seq.empty, resolver(true))) - assert(SchemaUtils.fieldExistsAtPath(root, Seq("a"), resolver(true))) - assert(!SchemaUtils.fieldExistsAtPath(root, Seq("A"), resolver(true))) - assert(SchemaUtils.fieldExistsAtPath(root, Seq("A"), resolver(false))) - assert(SchemaUtils.fieldExistsAtPath(root, Seq("S", "y"), resolver(true))) - 
assert(SchemaUtils.fieldExistsAtPath(root, Seq("arr", "element"), resolver(true))) - assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "key"), resolver(true))) - assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "value"), resolver(true))) - assert(!SchemaUtils.fieldExistsAtPath(root, Seq("missing"), resolver(true))) + assert(!SchemaUtils.fieldExistsAtPath(root, Seq.empty)) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("a"))) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + assert(!SchemaUtils.fieldExistsAtPath(root, Seq("A"))) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + assert(SchemaUtils.fieldExistsAtPath(root, Seq("A"))) + } + assert(SchemaUtils.fieldExistsAtPath(root, Seq("S", "y"))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("arr", "element"))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "key"))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "value"))) + assert(!SchemaUtils.fieldExistsAtPath(root, Seq("missing"))) } }