diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc5c580f1b08c..6196f86a960b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1229,6 +1229,14 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createOptional + val ADAPTIVE_BROADCAST_JOIN_FALLBACK_TO_SHUFFLE_ENABLED = + buildConf("spark.sql.adaptive.broadcastJoin.fallbackToShuffle.enabled") + .doc("When true, adaptive execution retries with broadcast joins disabled if a broadcast " + + "query stage fails because it exceeds broadcast table row or size limits.") + .version("4.2.0") + .booleanConf + .createWithDefault(false) + val ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD = buildConf("spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold") .doc("Configures the maximum size in bytes per partition that can be allowed to build " + @@ -7425,6 +7433,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def nonEmptyPartitionRatioForBroadcastJoin: Double = getConf(NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN) + def adaptiveBroadcastJoinFallbackToShuffleEnabled: Boolean = + getConf(ADAPTIVE_BROADCAST_JOIN_FALLBACK_TO_SHUFFLE_ENABLED) + def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED) def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 4840016bf745d..0c1ddb598ecc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -20,20 +20,29 @@ package 
org.apache.spark.sql.execution.adaptive import java.util import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} +import scala.annotation.tailrec import scala.collection.concurrent.TrieMap import scala.collection.mutable import scala.concurrent.ExecutionContext import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.broadcast import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MessageWithContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.{ + BROADCAST, + HintInfo, + Join, + JoinHint, + LogicalPlan, + NO_BROADCAST_HASH, + ReturnAnswer +} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -279,6 +288,9 @@ case class AdaptiveSparkPlanExec( val events = new LinkedBlockingQueue[StageMaterializationEvent]() val errors = new mutable.ArrayBuffer[Throwable]() var stagesToReplace = Seq.empty[QueryStageExec] + // Persist failed broadcast relations across AQE iterations so later replans + // cannot reintroduce the same relation and re-trigger the same broadcast failure. 
+ val failedBroadcastStagesBuffer = new mutable.ArrayBuffer[BroadcastQueryStageExec]() while (!result.allChildStagesMaterialized) { currentPhysicalPlan = result.newPlan if (result.newStages.nonEmpty) { @@ -328,14 +340,35 @@ case class AdaptiveSparkPlanExec( val nextMsg = events.take() val rem = new util.ArrayList[StageMaterializationEvent]() events.drainTo(rem) - (Seq(nextMsg) ++ rem.asScala).foreach { + val stageEvents = Seq(nextMsg) ++ rem.asScala + val fallbackStages = new mutable.ArrayBuffer[QueryStageExec]() + val fallbackErrors = new mutable.ArrayBuffer[Throwable]() + stageEvents.foreach { case StageSuccess(stage, res) => stage.resultOption.set(Some(res)) + case StageFailure(stage, ex) if shouldFallbackToShuffleJoin(stage, ex) => + logInfo("Broadcast query stage failed on table size/row limit; retrying " + + s"adaptive replanning without broadcast joins. Stage ${stage.id}") + removeStageFromCache(stage) + registerFailedBroadcastStage(failedBroadcastStagesBuffer, stage) + if (!fallbackStages.exists(_.eq(stage))) { + fallbackStages.append(stage) + } + fallbackErrors.append(ex) case StageFailure(stage, ex) => stage.error.set(Some(ex)) errors.append(ex) } + // Do not carry failed fallback stages into the next logical-plan replacement pass. + // They are intentionally invalidated and must be rebuilt by replanning. + if (fallbackStages.nonEmpty) { + stagesToReplace = filterStagesToReplaceForFallback(stagesToReplace, fallbackStages.toSeq) + currentLogicalPlan = removeFailedBroadcastStagesFromLogicalPlan( + currentLogicalPlan, + failedBroadcastStagesBuffer.toSeq) + } + // In case of errors, we cancel all running stages and throw exception. if (errors.nonEmpty) { cleanUpAndThrowException(errors.toSeq, None) @@ -353,12 +386,33 @@ case class AdaptiveSparkPlanExec( // plans are updated, we can clear the query stage list because at this point the two // plans are semantically and physically in sync again. 
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace) - val afterReOptimize = reOptimize(logicalPlan) + val shouldBroadcastFallback = fallbackStages.nonEmpty + val failedBroadcastStages = failedBroadcastStagesBuffer.toSeq + + val afterReOptimize = if (shouldBroadcastFallback) { + val targetedLogicalPlan = + addNoBroadcastHashHintsForFailedRelations(logicalPlan, failedBroadcastStages) + reOptimize(targetedLogicalPlan) + } else { + reOptimize(logicalPlan) + } + if (afterReOptimize.isDefined) { val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get + val rejectBroadcastFallbackPlan = hasFailedBroadcastRelation( + newPhysicalPlan, failedBroadcastStages) val origCost = costEvaluator.evaluateCost(currentPhysicalPlan) val newCost = costEvaluator.evaluateCost(newPhysicalPlan) - if (newCost < origCost || + if (rejectBroadcastFallbackPlan) { + if (shouldBroadcastFallback) { + logInfo("Adaptive fallback replan still contains failed broadcast relation; " + + "aborting without retrying broadcast join fallback.") + cleanUpAndThrowException(fallbackErrors.toSeq, None) + } else { + logDebug("Rejecting AQE replan because it reintroduces a previously failed " + + "broadcast relation.") + } + } else if (shouldBroadcastFallback || newCost < origCost || (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) { lazy val plans = sideBySide( currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n") @@ -368,6 +422,8 @@ case class AdaptiveSparkPlanExec( currentLogicalPlan = newLogicalPlan stagesToReplace = Seq.empty[QueryStageExec] } + } else if (shouldBroadcastFallback) { + cleanUpAndThrowException(fallbackErrors.toSeq, None) } } // Now that some stages have finished, we can try creating new stages. 
@@ -860,6 +916,210 @@ case class AdaptiveSparkPlanExec( } } + private def shouldFallbackToShuffleJoin(stage: QueryStageExec, error: Throwable): Boolean = { + // Detect errors such as: + // - Cannot broadcast the table over <maxBroadcastTableRows> rows: <numRows> rows. + // - Cannot broadcast the table that is larger than <maxBroadcastTableBytes>: <dataSize>. + // Check error-classes.json for details + conf.adaptiveBroadcastJoinFallbackToShuffleEnabled && + !isRequiredRootBroadcastStage(stage) && + stage.isInstanceOf[BroadcastQueryStageExec] && + hasErrorClass(error, "_LEGACY_ERROR_TEMP_2248", "_LEGACY_ERROR_TEMP_2249") + } + + // Remove failed fallback stages from the replacement set before replanning. + // For broadcast stages, remove semantically equivalent siblings as well, so + // reused/duplicated broadcast stages for the same relation don't leak through. + private def filterStagesToReplaceForFallback( + stagesToReplace: Seq[QueryStageExec], + failedStages: Seq[QueryStageExec]): Seq[QueryStageExec] = { + if (failedStages.isEmpty) return stagesToReplace + + val failedBroadcastStages = failedStages.collect { + case b: BroadcastQueryStageExec => b + } + + stagesToReplace.filterNot { + case stage if failedStages.exists(_ eq stage) => true + case stage: BroadcastQueryStageExec => + failedBroadcastStages.exists { failed => + sameBroadcastRelation(stage.broadcast.child, failed.broadcast.child) + } + case _ => false + } + } + + // Remove failed broadcast query-stage wrappers that may have been embedded in prior adopted + // logical plans, so fallback replanning does not keep retrying the same failed relation. 
+ private def removeFailedBroadcastStagesFromLogicalPlan( + plan: LogicalPlan, + failedBroadcastStages: Seq[BroadcastQueryStageExec]): LogicalPlan = { + if (failedBroadcastStages.isEmpty) { + plan + } else { + plan.transformDown { + case stage: LogicalQueryStage + if hasFailedBroadcastRelation(stage.physicalPlan, failedBroadcastStages) => + stage.logicalPlan + } + } + } + + // Check if a Spark physical plan contains any of the previous failed broadcast stages + private def hasFailedBroadcastRelation( + plan: SparkPlan, + failedBroadcastStages: Seq[BroadcastQueryStageExec]): Boolean = { + failedBroadcastStages.nonEmpty && plan.exists { p => + broadcastChildPlan(p).exists { child => + failedBroadcastStages.exists { failed => + sameBroadcastRelation(child, failed.broadcast.child) + } + } + } + } + + // Track failed broadcast relations using semantic equality so equivalent stages + // in later AQE replans are treated as the same failed relation. + private def registerFailedBroadcastStage( + failedBroadcastStages: mutable.ArrayBuffer[BroadcastQueryStageExec], + stage: QueryStageExec): Unit = { + stage match { + case failedStage: BroadcastQueryStageExec => + if (!failedBroadcastStages.exists { existing => + sameBroadcastRelation(existing.broadcast.child, failedStage.broadcast.child) + }) { + failedBroadcastStages.append(failedStage) + } + case _ => + } + } + + // Drop stale cache entries for a failed/replaced exchange stage so AQE can rebuild it. + private def removeStageFromCache(stage: QueryStageExec): Unit = stage match { + case exchangeStage: ExchangeQueryStageExec => + val planKey = exchangeStage.plan.canonicalized + val stageKey = exchangeStage.canonicalized + + context.stageCache.remove(planKey) + if (stageKey != planKey) { + context.stageCache.remove(stageKey) + } + case _ => + } + + // Normalize different wrappers to the underlying broadcast child plan when present. 
+ private def broadcastChildPlan(plan: SparkPlan): Option[SparkPlan] = plan match { + case stage: BroadcastQueryStageExec => Some(stage.broadcast.child) + case exchange: BroadcastExchangeLike => Some(exchange.child) + case ReusedExchangeExec(_, exchange: BroadcastExchangeLike) => Some(exchange.child) + case _ => None + } + + // Compare broadcast inputs by semantic plan equivalence. + // Do not fall back to output-attribute equality: different plans can share + // output schemas while producing different relations. + private def sameBroadcastRelation(left: SparkPlan, right: SparkPlan): Boolean = { + left.sameResult(right) + } + + // Root broadcast stages must be preserved (e.g. broadcast subquery output contract). + // Only exempt the actual root broadcast stage instance, not semantically equivalent relations. + private def isRequiredRootBroadcastStage(stage: QueryStageExec): Boolean = stage match { + case broadcastStage: BroadcastQueryStageExec if inputPlan.isInstanceOf[BroadcastExchangeLike] => + @tailrec + def findRootBroadcast(current: SparkPlan): Option[BroadcastQueryStageExec] = current match { + case root: BroadcastQueryStageExec => Some(root) + case resultStage: ResultQueryStageExec => findRootBroadcast(resultStage.plan) + case _ => None + } + findRootBroadcast(currentPhysicalPlan).exists(_ eq broadcastStage) + case _ => false + } + + // Apply side-specific NO_BROADCAST_HASH hints for relations that previously failed + // broadcast stage materialization, so targeted fallback can keep unrelated BHJs. 
+ private def addNoBroadcastHashHintsForFailedRelations( + plan: LogicalPlan, + failedBroadcastStages: Seq[BroadcastQueryStageExec]): LogicalPlan = { + val failedLogicalPlans = extractFailedBroadcastLogicalPlans(failedBroadcastStages) + if (failedLogicalPlans.isEmpty) { + plan + } else { + plan.transformDown { + case join: Join => + val disableLeft = shouldDisableBroadcastForJoinSide(join.left, failedLogicalPlans) + val disableRight = shouldDisableBroadcastForJoinSide(join.right, failedLogicalPlans) + if (!disableLeft && !disableRight) { + join + } else { + val newHint = JoinHint( + if (disableLeft) { + toNoBroadcastHashHint(join.hint.leftHint) + } else { + join.hint.leftHint + }, + if (disableRight) { + toNoBroadcastHashHint(join.hint.rightHint) + } else { + join.hint.rightHint + }) + if (newHint != join.hint) join.copy(hint = newHint) else join + } + } + } + } + + private def extractFailedBroadcastLogicalPlans( + failedBroadcastStages: Seq[BroadcastQueryStageExec]): Seq[LogicalPlan] = { + val failedLogicalPlans = new mutable.ArrayBuffer[LogicalPlan]() + failedBroadcastStages.foreach { stage => + Seq(stage.broadcast.child.logicalLink, stage.broadcast.logicalLink, stage.logicalLink) + .flatten + .foreach { logicalPlan => + if (!failedLogicalPlans.exists(_.sameResult(logicalPlan))) { + failedLogicalPlans.append(logicalPlan) + } + } + } + failedLogicalPlans.toSeq + } + + private def shouldDisableBroadcastForJoinSide( + side: LogicalPlan, + failedLogicalPlans: Seq[LogicalPlan]): Boolean = { + failedLogicalPlans.exists(side.sameResult) + } + + private def toNoBroadcastHashHint(hint: Option[HintInfo]): Option[HintInfo] = { + hint match { + // Only rewrite explicit BROADCAST hints. Preserve existing non-broadcast + // strategies such as NO_BROADCAST_AND_REPLICATION, SHUFFLE_HASH, etc. 
+ case Some(h) if h.strategy.contains(BROADCAST) => + Some(h.copy(strategy = Some(NO_BROADCAST_HASH))) + case Some(_) => hint + case None => Some(HintInfo(strategy = Some(NO_BROADCAST_HASH))) + } + } + + // Walk throwable causes and return true if any SparkThrowable has one of the error classes. + private def hasErrorClass(error: Throwable, errorClasses: String*): Boolean = { + @tailrec + def loop(current: Throwable): Boolean = { + if (current == null) { + false + } else { + current match { + case sparkThrowable: SparkThrowable + if errorClasses.contains(sparkThrowable.getCondition) => + true + case _ => + loop(current.getCause) + } + } + } + loop(error) + } + private def assertStageNotFailed(stage: QueryStageExec): Unit = { if (stage.hasFailed) { throw stage.error.get().get match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a3e6e65d5027a..d4f8898fdab25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -31,8 +31,18 @@ import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ + Aggregate, + BROADCAST, + HintInfo, + Join, + LogicalPlan, + NO_BROADCAST_AND_REPLICATION, + NO_BROADCAST_HASH +} +import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode import org.apache.spark.sql.classic.Strategy +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import 
org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike} @@ -624,6 +634,357 @@ class AdaptiveQueryExecSuite } } + test("Fallback to shuffled join should apply NO_BROADCAST_HASH only to failed relation") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_BROADCAST_JOIN_FALLBACK_TO_SHUFFLE_ENABLED.key -> "true", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + val df = sql( + "SELECT /*+ BROADCAST(t2), BROADCAST(t3) */ * " + + "FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a " + + "JOIN testData3 t3 ON t1.key = t3.a") + val adaptivePlan = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] + val logicalPlan = adaptivePlan.inputPlan.logicalLink.get + val addNoBroadcastHashHintsForFailedRelations = + PrivateMethod[LogicalPlan](Symbol("addNoBroadcastHashHintsForFailedRelations")) + + val failedExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT * FROM testData2").queryExecution.sparkPlan) + val failedStage = BroadcastQueryStageExec(0, failedExchange, failedExchange.canonicalized) + + val targetedLogicalPlan = adaptivePlan.invokePrivate( + addNoBroadcastHashHintsForFailedRelations(logicalPlan, Seq(failedStage))) + + val joinsWithNoBroadcastHash = targetedLogicalPlan.collect { + case join: Join + if join.hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) || + join.hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) => + join + } + assert(joinsWithNoBroadcastHash.nonEmpty) + assert(targetedLogicalPlan.collect { + case join: Join + if !join.hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) && + !join.hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) => + join + }.nonEmpty) + } + } + + test("Fallback to shuffled join should preserve non-broadcast 
side hint strategies") { + val qe = sql("SELECT key FROM testData").queryExecution + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val toNoBroadcastHashHint = + PrivateMethod[Option[HintInfo]](Symbol("toNoBroadcastHashHint")) + + val noBroadcastAndReplicationHint = + Some(HintInfo(strategy = Some(NO_BROADCAST_AND_REPLICATION))) + val preservedNoBroadcastAndReplication = + adaptivePlan.invokePrivate(toNoBroadcastHashHint(noBroadcastAndReplicationHint)) + assert(preservedNoBroadcastAndReplication == noBroadcastAndReplicationHint) + + val noBroadcastHashHint = Some(HintInfo(strategy = Some(NO_BROADCAST_HASH))) + val preservedNoBroadcastHash = + adaptivePlan.invokePrivate(toNoBroadcastHashHint(noBroadcastHashHint)) + assert(preservedNoBroadcastHash == noBroadcastHashHint) + + val broadcastHint = Some(HintInfo(strategy = Some(BROADCAST))) + val rewrittenBroadcastHint = adaptivePlan.invokePrivate(toNoBroadcastHashHint(broadcastHint)) + assert(rewrittenBroadcastHint.exists(_.strategy.contains(NO_BROADCAST_HASH))) + } + + test("Fallback to shuffled join should keep unrelated BHJs after targeted fallback replan") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_BROADCAST_JOIN_FALLBACK_TO_SHUFFLE_ENABLED.key -> "true", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + val df = sql( + "SELECT /*+ BROADCAST(t2), BROADCAST(t3) */ * " + + "FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a " + + "JOIN testData3 t3 ON t1.key = t3.a") + val adaptivePlan = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] + val logicalPlan = adaptivePlan.inputPlan.logicalLink.get + + val reOptimize = + PrivateMethod[Option[(SparkPlan, LogicalPlan)]](Symbol("reOptimize")) + val hasFailedBroadcastRelation = + 
PrivateMethod[Boolean](Symbol("hasFailedBroadcastRelation")) + val addNoBroadcastHashHintsForFailedRelations = + PrivateMethod[LogicalPlan](Symbol("addNoBroadcastHashHintsForFailedRelations")) + + val failedExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT * FROM testData2").queryExecution.sparkPlan) + val failedStage = BroadcastQueryStageExec(0, failedExchange, failedExchange.canonicalized) + + val baselinePlan = adaptivePlan.invokePrivate(reOptimize(logicalPlan)).get._1 + assert(adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(baselinePlan, Seq(failedStage)))) + + val targetedLogicalPlan = adaptivePlan.invokePrivate( + addNoBroadcastHashHintsForFailedRelations(logicalPlan, Seq(failedStage))) + val targetedPlan = adaptivePlan.invokePrivate(reOptimize(targetedLogicalPlan)).get._1 + + assert(findTopLevelBroadcastHashJoin(targetedPlan).nonEmpty) + assert(!adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(targetedPlan, Seq(failedStage)))) + } + } + + test("Fallback to shuffled join should recognize broadcast limit failures") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_BROADCAST_JOIN_FALLBACK_TO_SHUFFLE_ENABLED.key -> "true") { + val qe = sql("SELECT key FROM testData").queryExecution + val broadcastInputPlan = BroadcastExchangeExec(IdentityBroadcastMode, qe.sparkPlan) + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val broadcastSubqueryAdaptivePlan = AdaptiveSparkPlanExec( + broadcastInputPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = true) + val regularSubqueryAdaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = true) + val stage = + BroadcastQueryStageExec(0, broadcastInputPlan, broadcastInputPlan.canonicalized) + val innerBroadcastPlan = + BroadcastExchangeExec( + IdentityBroadcastMode, + sql("SELECT a FROM 
testData2").queryExecution.sparkPlan) + val innerStage = + BroadcastQueryStageExec(1, innerBroadcastPlan, innerBroadcastPlan.canonicalized) + val equivalentInnerBroadcastPlan = + BroadcastExchangeExec(IdentityBroadcastMode, qe.sparkPlan) + val equivalentInnerStage = BroadcastQueryStageExec( + 2, equivalentInnerBroadcastPlan, equivalentInnerBroadcastPlan.canonicalized) + val shouldFallbackToShuffleJoin = + PrivateMethod[Boolean](Symbol("shouldFallbackToShuffleJoin")) + + assert(adaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(1, 2)))) + assert(adaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(1, 2)))) + assert(regularSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(1, 2)))) + assert(regularSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(1, 2)))) + // `stage` is a synthetic stage instance, not the stage currently at the plan root. + // Only the actual root stage instance is exempt from fallback. 
+ assert(broadcastSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(1, 2)))) + assert(broadcastSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(1, 2)))) + assert(broadcastSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + innerStage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(1, 2)))) + assert(broadcastSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + innerStage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(1, 2)))) + assert(broadcastSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + equivalentInnerStage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(1, 2)))) + assert(broadcastSubqueryAdaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + equivalentInnerStage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(1, 2)))) + } + } + + test("Fallback to shuffled join should be disabled when conf is false") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_BROADCAST_JOIN_FALLBACK_TO_SHUFFLE_ENABLED.key -> "false") { + val qe = sql("SELECT key FROM testData").queryExecution + val broadcastInputPlan = BroadcastExchangeExec(IdentityBroadcastMode, qe.sparkPlan) + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val stage = + BroadcastQueryStageExec(0, broadcastInputPlan, broadcastInputPlan.canonicalized) + val shouldFallbackToShuffleJoin = + PrivateMethod[Boolean](Symbol("shouldFallbackToShuffleJoin")) + + assert(!adaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(1, 2)))) + assert(!adaptivePlan.invokePrivate(shouldFallbackToShuffleJoin( + stage, + 
QueryExecutionErrors.cannotBroadcastTableOverMaxTableRowsError(1, 2)))) + } + } + + test("Re-optimize should preserve input broadcast exchange") { + val qe = sql("SELECT key FROM testData").queryExecution + val inputPlan = BroadcastExchangeExec(IdentityBroadcastMode, qe.sparkPlan) + val adaptivePlan = AdaptiveSparkPlanExec( + inputPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = true) + val logicalPlan = qe.optimizedPlan + val reOptimize = + PrivateMethod[Option[(SparkPlan, LogicalPlan)]](Symbol("reOptimize")) + + val replanned = adaptivePlan.invokePrivate(reOptimize(logicalPlan)).get._1 + assert(replanned.isInstanceOf[BroadcastExchangeExec]) + } + + test("Fallback to shuffled join should reject replans containing failed broadcast relation") { + val qe = sql("SELECT key FROM testData").queryExecution + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val hasFailedBroadcastRelation = + PrivateMethod[Boolean](Symbol("hasFailedBroadcastRelation")) + + val failedExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val failedStage = BroadcastQueryStageExec(0, failedExchange, failedExchange.canonicalized) + + val sameExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val reusedSameExchange = ReusedExchangeExec(sameExchange.output, sameExchange) + val sameOutputDifferentRelation = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2 WHERE a > 1").queryExecution.sparkPlan) + val otherExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT key FROM testData").queryExecution.sparkPlan) + + assert(adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(sameExchange, Seq(failedStage)))) + assert(adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(reusedSameExchange, Seq(failedStage)))) + 
assert(!adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(sameOutputDifferentRelation, Seq(failedStage)))) + assert(!adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(otherExchange, Seq(failedStage)))) + assert(!adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(qe.sparkPlan, Seq(failedStage)))) + assert(!adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(sameExchange, Seq.empty[BroadcastQueryStageExec]))) + } + + test("Fallback to shuffled join should persist failed broadcast relation across iterations") { + val qe = sql("SELECT key FROM testData").queryExecution + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val registerFailedBroadcastStage = + PrivateMethod[Unit](Symbol("registerFailedBroadcastStage")) + val hasFailedBroadcastRelation = + PrivateMethod[Boolean](Symbol("hasFailedBroadcastRelation")) + + val failedStages = new scala.collection.mutable.ArrayBuffer[BroadcastQueryStageExec]() + val failedExchange1 = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val failedStage1 = BroadcastQueryStageExec(0, failedExchange1, failedExchange1.canonicalized) + adaptivePlan.invokePrivate(registerFailedBroadcastStage(failedStages, failedStage1)) + + // Equivalent failed relation from a later iteration should be deduplicated. 
+ val failedExchange2 = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val failedStage2 = BroadcastQueryStageExec(1, failedExchange2, failedExchange2.canonicalized) + adaptivePlan.invokePrivate(registerFailedBroadcastStage(failedStages, failedStage2)) + + assert(failedStages.size == 1) + + val sameExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + assert(adaptivePlan.invokePrivate( + hasFailedBroadcastRelation(sameExchange, failedStages.toSeq))) + } + + test("Fallback to shuffled join should remove equivalent failed broadcast stages") { + val qe = sql("SELECT key FROM testData").queryExecution + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val filterStagesToReplaceForFallback = + PrivateMethod[Seq[QueryStageExec]](Symbol("filterStagesToReplaceForFallback")) + + val failedExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val failedStage = BroadcastQueryStageExec(0, failedExchange, failedExchange.canonicalized) + + val equivalentExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val equivalentStage = + BroadcastQueryStageExec(1, equivalentExchange, equivalentExchange.canonicalized) + + val otherExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT key FROM testData").queryExecution.sparkPlan) + val otherStage = BroadcastQueryStageExec(2, otherExchange, otherExchange.canonicalized) + + val filtered = adaptivePlan.invokePrivate( + filterStagesToReplaceForFallback( + Seq(failedStage, equivalentStage, otherStage), + Seq(failedStage))) + + assert(!filtered.exists(_.eq(failedStage))) + assert(!filtered.exists(_.eq(equivalentStage))) + assert(filtered.exists(_.eq(otherStage))) + } + + test("Fallback to shuffled 
join should remove failed broadcast logical query stages") { + val qe = sql("SELECT key FROM testData").queryExecution + val adaptivePlan = AdaptiveSparkPlanExec( + qe.sparkPlan, + AdaptiveExecutionContext(spark, qe), + Seq.empty, + isSubquery = false) + val removeFailedBroadcastStagesFromLogicalPlan = + PrivateMethod[LogicalPlan](Symbol("removeFailedBroadcastStagesFromLogicalPlan")) + + val failedLogical = sql("SELECT a FROM testData2").queryExecution.optimizedPlan + val failedExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT a FROM testData2").queryExecution.sparkPlan) + val failedStage = BroadcastQueryStageExec(0, failedExchange, failedExchange.canonicalized) + val failedLogicalStage = LogicalQueryStage(failedLogical, failedStage) + + val otherLogical = sql("SELECT key FROM testData").queryExecution.optimizedPlan + val otherExchange = BroadcastExchangeExec( + IdentityBroadcastMode, sql("SELECT key FROM testData").queryExecution.sparkPlan) + val otherStage = BroadcastQueryStageExec(1, otherExchange, otherExchange.canonicalized) + val otherLogicalStage = LogicalQueryStage(otherLogical, otherStage) + + val logicalPlan = + org.apache.spark.sql.catalyst.plans.logical.Union(Seq(failedLogicalStage, otherLogicalStage)) + val updated = adaptivePlan.invokePrivate( + removeFailedBroadcastStagesFromLogicalPlan(logicalPlan, Seq(failedStage))) + val remainingLogicalStages = updated.collect { + case stage: LogicalQueryStage => stage + } + + assert(remainingLogicalStages.size == 1) + assert(remainingLogicalStages.head.physicalPlan.eq(otherStage)) + } + test("Union/Except/Intersect queries") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { runAdaptiveAndVerifyResult(