Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,80 @@ public long rtHandle() {

/**
 * Clone a hash table.
 *
 * <p>NOTE(review): clone semantics (deep copy vs. shared underlying data) are defined by the
 * native implementation — confirm before relying on independent lifetimes of the two handles.
 *
 * @param hashTableData Handle to the hash table data to clone
 * @return Handle to the cloned hash table
 */
public static native long cloneHashTable(long hashTableData);

/**
 * Serialize a hash table for broadcasting.
 *
 * <p>The returned handle refers to native serialized data and must be released with
 * {@link #releaseSerializedData(long)} once the data has been extracted.
 *
 * @param hashTableHandle Handle to the hash table builder
 * @return Handle to the serialized hash table data
 */
public static native long serializeHashTable(long hashTableHandle);

/**
 * Deserialize a hash table from broadcast data. Uses the default leaf memory pool for allocation.
 *
 * <p>See {@link #deserializeHashTableWithIgnoreNullKeys(byte[], boolean, boolean)} for the
 * variant that restores the null-key flags explicitly.
 *
 * @param serializedData Byte array containing serialized hash table
 * @return Handle to the deserialized hash table builder
 */
public static native long deserializeHashTable(byte[] serializedData);

/**
 * Deserialize a hash table from broadcast data with explicit ignoreNullKeys parameter.
 *
 * <p>The correct flag values can be recovered from the serialized metadata via
 * {@link #getSerializedIgnoreNullKeys(long)} and {@link #getSerializedJoinHasNullKeys(long)}.
 *
 * @param serializedData Byte array containing serialized hash table
 * @param ignoreNullKeys Whether to ignore null keys (must match the serialized hash table)
 * @param joinHasNullKeys Whether the build side has null keys (for null-aware anti join)
 * @return Handle to the deserialized hash table builder
 */
public static native long deserializeHashTableWithIgnoreNullKeys(
byte[] serializedData, boolean ignoreNullKeys, boolean joinHasNullKeys);

/**
 * Get the size of serialized hash table data.
 *
 * <p>NOTE(review): the contract here is a byte size, yet callers on the Scala side store this
 * value as a row count — verify whether the native method returns bytes or rows and fix the
 * caller or this doc accordingly.
 *
 * @param serializedHandle Handle to serialized data
 * @return Size in bytes
 */
public static native long getSerializedSize(long serializedHandle);

/**
 * Get ignoreNullKeys parameter from serialized hash table metadata.
 *
 * <p>Intended for passing the matching flag to
 * {@link #deserializeHashTableWithIgnoreNullKeys(byte[], boolean, boolean)}.
 *
 * @param serializedHandle Handle to serialized data
 * @return ignoreNullKeys flag used when building the hash table
 */
public static native boolean getSerializedIgnoreNullKeys(long serializedHandle);

/**
 * Get joinHasNullKeys parameter from serialized hash table metadata.
 *
 * <p>Intended for passing the matching flag to
 * {@link #deserializeHashTableWithIgnoreNullKeys(byte[], boolean, boolean)}.
 *
 * @param serializedHandle Handle to serialized data
 * @return joinHasNullKeys flag indicating if build side has null keys
 */
public static native boolean getSerializedJoinHasNullKeys(long serializedHandle);

/**
 * Get bloom filter blocks byte size from serialized hash table metadata.
 *
 * <p>A non-zero value indicates a bloom filter was captured alongside the hash table.
 *
 * @param serializedHandle Handle to serialized data
 * @return bloom filter blocks byte size
 */
public static native long getBloomFilterBlocksByteSize(long serializedHandle);

/**
 * Get serialized hash table data as byte array.
 *
 * <p>NOTE(review): presumably returns a JVM-side copy, with the native buffer still owned by
 * the handle until {@link #releaseSerializedData(long)} — confirm in the native implementation.
 *
 * @param serializedHandle Handle to serialized data
 * @return Byte array containing serialized data
 */
public static native byte[] getSerializedData(long serializedHandle);

/**
 * Release serialized hash table data.
 *
 * <p>Must be called for every handle returned by {@link #serializeHashTable(long)} to avoid
 * leaking native memory; the handle must not be used afterwards.
 *
 * @param serializedHandle Handle to serialized data
 */
public static native void releaseSerializedData(long serializedHandle);

public native long nativeBuild(
String buildHashTableId,
long[] batchHandlers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {
getConf(VALUE_STREAM_DYNAMIC_FILTER_ENABLED)

/** Whether validation of the TIMESTAMP_NTZ type is enabled. */
def enableTimestampNtzValidation: Boolean = getConf(ENABLE_TIMESTAMP_NTZ_VALIDATION)

/**
 * Whether the broadcast hash table is built and serialized on the driver and shipped to
 * executors, instead of each executor building its own copy. Backed by
 * `VELOX_DRIVER_SIDE_BROADCAST_HASH_TABLE_BUILD`.
 */
def enableDriverSideBroadcastHashTableBuild: Boolean =
getConf(VELOX_DRIVER_SIDE_BROADCAST_HASH_TABLE_BUILD)
}

object VeloxConfig extends ConfigRegistry {
Expand Down Expand Up @@ -622,6 +625,18 @@ object VeloxConfig extends ConfigRegistry {
.booleanConf
.createWithDefault(true)

// NOTE(review): the doc text below explicitly warns this feature "may have issues" with
// Semi/Anti-Join, sorting and complex filter conditions, yet the default is `true`.
// Confirm that shipping a potentially-incorrect feature enabled-by-default is intended;
// otherwise default to `false` until the known issues are resolved.
val VELOX_DRIVER_SIDE_BROADCAST_HASH_TABLE_BUILD =
buildConf("spark.gluten.sql.columnar.backend.velox.driverSideBroadcastHashTableBuild")
.doc(
"Enable driver-side broadcast hash table build. When enabled, the hash table is " +
"built and serialized on the driver, then broadcast to executors. When disabled, " +
"each executor builds its own hash table from the broadcast data. " +
"Note: This feature may have issues with complex queries involving Semi/Anti-Join, " +
"sorting, or complex filter conditions. Consider disabling if you encounter " +
"incorrect results in such queries.")
.booleanConf
.createWithDefault(true)

val QUERY_TRACE_ENABLED = buildConf("spark.gluten.sql.columnar.backend.velox.queryTraceEnabled")
.doc("Enable query tracing flag.")
.booleanConf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{ColumnarBuildSideRelation, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import io.substrait.proto.JoinRel
Expand Down Expand Up @@ -179,7 +180,48 @@ case class BroadcastHashJoinExecTransformer(
bloomFilterPushdownSize,
metrics.get("buildHashTableTime")
)
val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast, context)

// Check if the build side can be offloaded to Velox
// If offload=false (e.g., due to unsupported operators), we must use executor-side build
val canOffload = broadcast.value match {
case columnar: ColumnarBuildSideRelation => columnar.offload
case unsafe: UnsafeColumnarBuildSideRelation => unsafe.offload
case _ => false
}

// Choose between driver-side and executor-side hash table build
val broadcastRDD = if (VeloxBroadcastBuildSideCache.isDriverSideBuildEnabled && canOffload) {
// New approach: Build and serialize hash table on driver
// Only use this when the build side can be offloaded to Velox
logInfo(s"Using driver-side broadcast hash table build for $buildBroadcastTableId")
val serializedHashTable = VeloxBroadcastBuildSideCache.buildAndSerializeOnDriver(
broadcast,
context
)
val broadcastSerialized = sparkContext.broadcast(serializedHashTable)
val rdd = VeloxSerializedBroadcastRDD(sparkContext, broadcastSerialized, context)

// Update bloom filter metrics from driver-side build
val (bloomFilterSize, dynamicFiltersProduced) = rdd.getBloomFilterMetrics
metrics.get("bloomFilterBlocksByteSize").foreach(_.set(bloomFilterSize))
metrics.get("hashProbeDynamicFiltersProduced").foreach(_.set(dynamicFiltersProduced))

rdd
} else {
// Legacy approach: Build hash table on each executor
// Use this when:
// 1. Driver-side build is disabled, OR
// 2. The build side cannot be offloaded (offload=false due to unsupported operators)
if (!canOffload) {
logWarning(
s"Build side cannot be offloaded for $buildBroadcastTableId, " +
"falling back to executor-side build")
} else {
logInfo(s"Using executor-side broadcast hash table build for $buildBroadcastTableId")
}
VeloxBroadcastBuildSideRDD(sparkContext, broadcast, context)
}

// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.vectorized.HashJoinBuilder

import org.apache.spark.sql.execution.joins.BuildSideRelation

import java.io.{Externalizable, ObjectInput, ObjectOutput}

/**
 * Serialized broadcast hash table that can be efficiently broadcast to executors. This is built on
 * the driver and contains the serialized hash table data.
 *
 * Custom [[Externalizable]] serialization is used (instead of default Java serialization) so the
 * large byte payload is written with explicit length framing via `writeInt`/`readFully`.
 *
 * @param serializedData opaque Velox hash table bytes produced by the native serializer
 * @param numRows row-count metadata (NOTE(review): populated from `getSerializedSize`, whose
 *                contract is "size in bytes" — confirm the intended semantics)
 * @param ignoreNullKeys flag the table was built with; passed back on deserialization
 * @param joinHasNullKeys whether the build side contains null keys (null-aware anti join)
 * @param bloomFilterBlocksByteSize bloom filter metric captured at build time
 * @param hashProbeDynamicFiltersProduced dynamic-filter metric captured at build time
 * @param buildSideRelation original build side relation, carried along for metadata
 */
case class SerializedBroadcastHashTable(
serializedData: Array[Byte],
numRows: Long,
ignoreNullKeys: Boolean,
joinHasNullKeys: Boolean,
bloomFilterBlocksByteSize: Long,
hashProbeDynamicFiltersProduced: Long,
buildSideRelation: BuildSideRelation)
extends Externalizable {

// No-arg constructor required by the Externalizable contract; the placeholder values are
// overwritten field-by-field in readExternal.
def this() = this(null, 0, false, false, 0, 0, null) // Required for Externalizable

override def writeExternal(out: ObjectOutput): Unit = {
// IMPORTANT: the write order here must match the read order in readExternal exactly.
out.writeLong(numRows)
out.writeBoolean(ignoreNullKeys)
out.writeBoolean(joinHasNullKeys)
out.writeLong(bloomFilterBlocksByteSize)
out.writeLong(hashProbeDynamicFiltersProduced)
out.writeInt(serializedData.length)
out.write(serializedData)
out.writeObject(buildSideRelation)
}

override def readExternal(in: ObjectInput): Unit = {
// Read fields in exactly the order written by writeExternal.
val numRows = in.readLong()
val ignoreNullKeys = in.readBoolean()
val joinHasNullKeys = in.readBoolean()
val bloomFilterBlocksByteSize = in.readLong()
val hashProbeDynamicFiltersProduced = in.readLong()
val dataLength = in.readInt()
val data = new Array[Byte](dataLength)
in.readFully(data)
val relation = in.readObject().asInstanceOf[BuildSideRelation]

// The case-class fields are final vals, so reflection is used to assign the values read
// above onto this (freshly no-arg-constructed) instance.
// Use reflection to set final fields
val numRowsField = classOf[SerializedBroadcastHashTable].getDeclaredField("numRows")
numRowsField.setAccessible(true)
numRowsField.set(this, numRows)

val dataField = classOf[SerializedBroadcastHashTable].getDeclaredField("serializedData")
dataField.setAccessible(true)
dataField.set(this, data)

val relationField = classOf[SerializedBroadcastHashTable].getDeclaredField("buildSideRelation")
relationField.setAccessible(true)
relationField.set(this, relation)

val ignoreNullKeysField =
classOf[SerializedBroadcastHashTable].getDeclaredField("ignoreNullKeys")
ignoreNullKeysField.setAccessible(true)
ignoreNullKeysField.set(this, ignoreNullKeys)

val joinHasNullKeysField =
classOf[SerializedBroadcastHashTable].getDeclaredField("joinHasNullKeys")
joinHasNullKeysField.setAccessible(true)
joinHasNullKeysField.set(this, joinHasNullKeys)

val bloomFilterBlocksByteSizeField =
classOf[SerializedBroadcastHashTable].getDeclaredField("bloomFilterBlocksByteSize")
bloomFilterBlocksByteSizeField.setAccessible(true)
bloomFilterBlocksByteSizeField.set(this, bloomFilterBlocksByteSize)

val hashProbeDynamicFiltersProducedField =
classOf[SerializedBroadcastHashTable].getDeclaredField("hashProbeDynamicFiltersProduced")
hashProbeDynamicFiltersProducedField.setAccessible(true)
hashProbeDynamicFiltersProducedField.set(this, hashProbeDynamicFiltersProduced)
}

/**
 * Deserialize the hash table on executor side. The serialized Velox hash table is already in a
 * prepared, probe-ready form, so executor side only needs deserialization without re-running
 * prepareJoinTable.
 *
 * @return
 * Hash table builder handle
 */
def deserialize(): Long = {
HashJoinBuilder.deserializeHashTableWithIgnoreNullKeys(
serializedData,
ignoreNullKeys,
joinHasNullKeys)
}

/** Get the size of serialized data in bytes. */
def sizeInBytes: Long = serializedData.length.toLong
}

object SerializedBroadcastHashTable {

/**
 * Build and serialize a hash table on the driver.
 *
 * Ownership: this method always consumes `hashTableHandle` — the native hash table is released
 * via `clearHashTable` on both success and failure, so the caller must not reuse the handle.
 *
 * @param hashTableHandle
 * Handle to the built hash table
 * @param buildSideRelation
 * The build side relation for metadata
 * @return
 * Serialized broadcast hash table
 */
def fromHashTable(
hashTableHandle: Long,
buildSideRelation: BuildSideRelation): SerializedBroadcastHashTable = {
try {
// Serialize the hash table into a native buffer identified by serializedHandle.
// Done inside the outer try so a serialization failure still releases hashTableHandle
// (previously a throw here leaked the native hash table).
val serializedHandle = HashJoinBuilder.serializeHashTable(hashTableHandle)
try {
// Extract payload and metadata from the native serialized buffer.
val serializedData = HashJoinBuilder.getSerializedData(serializedHandle)
// NOTE(review): getSerializedSize is documented as "size in bytes" but is stored here
// as numRows — confirm the native side returns a row count, or rename the field.
val numRows = HashJoinBuilder.getSerializedSize(serializedHandle)
val ignoreNullKeys = HashJoinBuilder.getSerializedIgnoreNullKeys(serializedHandle)
val joinHasNullKeys = HashJoinBuilder.getSerializedJoinHasNullKeys(serializedHandle)

// Bloom filter metrics: a non-zero blocks size is treated as one dynamic filter produced.
val bloomFilterBlocksByteSize =
HashJoinBuilder.getBloomFilterBlocksByteSize(serializedHandle)
val hashProbeDynamicFiltersProduced = if (bloomFilterBlocksByteSize > 0) 1L else 0L

SerializedBroadcastHashTable(
serializedData,
numRows,
ignoreNullKeys,
joinHasNullKeys,
bloomFilterBlocksByteSize,
hashProbeDynamicFiltersProduced,
buildSideRelation)
} finally {
// Always free the native serialization buffer.
HashJoinBuilder.releaseSerializedData(serializedHandle)
}
} finally {
// Always free the original native hash table — even if serialization or the
// release above throws — so no native memory leaks on any failure path.
HashJoinBuilder.clearHashTable(hashTableHandle)
}
}
}
Loading
Loading