Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ template <typename ALayout,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
typename ComputeTypeB = BDataType,
typename BiasDataType = CDataType>
struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
BLayout,
DsLayout,
Expand Down Expand Up @@ -419,6 +420,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t NumTokens,
Expand All @@ -444,6 +446,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
static_cast<const BiasDataType*>(p_bias),
p_ds,
static_cast<CDataType*>(p_c),
NumTokens,
Expand Down Expand Up @@ -493,6 +496,7 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
nullptr,
p_ds,
static_cast<CDataType*>(p_c),
M, // randoms set, no use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ template <typename ALayout,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
typename ComputeTypeB = BDataType,
typename BiasDataType = CDataType>
struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
BLayout,
DsLayout,
Expand Down Expand Up @@ -446,6 +447,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t NumTokens,
Expand All @@ -471,6 +473,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
static_cast<const BiasDataType*>(p_bias),
p_ds,
static_cast<CDataType*>(p_c),
NumTokens,
Expand Down Expand Up @@ -520,6 +523,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
nullptr,
p_ds,
static_cast<CDataType*>(p_c),
M, // randoms set, no use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ enum Activation
swiglustep_and_mul = 2
};

// Clamped SwiGLU activation for the swiglustep_and_mul path:
// clamps the gate term from above only, clamps the linear term symmetrically,
// then returns gate * sigmoid(alpha * gate) * (linear + 1).
// Note: the `x < limit ? x : limit` form deliberately maps NaN inputs to the
// clamp bound, matching the original comparison semantics.
__host__ __device__ inline float compute_swiglu_and_mul(float gate, float linear)
{
    constexpr float kAlpha = 1.702f; // sigmoid sharpness (GELU-approximation constant)
    constexpr float kLimit = 7.0f;   // clamp bound for both operands

    // gate: clamp to (-inf, kLimit]; linear: clamp to [-kLimit, kLimit].
    const float g      = gate < kLimit ? gate : kLimit;
    const float l_high = linear < kLimit ? linear : kLimit;
    const float l      = l_high > -kLimit ? l_high : -kLimit;

    const float sig = 1.0f / (1.0f + math::exp(-kAlpha * g));
    return g * sig * (l + 1.0f);
}

template <typename ALayout,
typename BLayout,
typename ELayout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy)
karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
karg.p_bias_grid,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
Expand Down Expand Up @@ -128,7 +129,8 @@ template <typename ALayout,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
typename ComputeTypeB = BDataType,
typename BiasDataType = CDataType>
struct GridwiseMoeGemmMXBNS
: public GridwiseGemm_xdl_cshuffle_base<
ALayout,
Expand Down Expand Up @@ -703,6 +705,7 @@ struct GridwiseMoeGemmMXBNS
const AScaleDataType* p_a_scale_grid_,
const BDataType* p_b_grid_,
const BScaleDataType* p_b_scale_grid_,
const BiasDataType* p_bias_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_c_grid_,
index_t NumTokens_,
Expand Down Expand Up @@ -739,6 +742,7 @@ struct GridwiseMoeGemmMXBNS
p_a_scale_grid{p_a_scale_grid_},
p_b_grid{p_b_grid_},
p_b_scale_grid{p_b_scale_grid_},
p_bias_grid{p_bias_grid_},
p_ds_grid{},
p_c_grid{p_c_grid_},
a_element_op{a_element_op_},
Expand All @@ -762,6 +766,7 @@ struct GridwiseMoeGemmMXBNS
const AScaleDataType* p_a_scale_grid;
const BDataType* p_b_grid;
const BScaleDataType* p_b_scale_grid;
const BiasDataType* p_bias_grid;
DsGridPointer p_ds_grid;
CDataType* p_c_grid;

Expand Down Expand Up @@ -1099,6 +1104,7 @@ struct GridwiseMoeGemmMXBNS
const AScaleDataType* p_a_scale_grid,
const BDataType* p_b_grid,
const BScaleDataType* p_b_scale_grid,
const BiasDataType* p_bias_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
Expand Down Expand Up @@ -1517,10 +1523,34 @@ struct GridwiseMoeGemmMXBNS
static_assert(M5 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
const BiasDataType* p_bias_col = nullptr;
const BiasDataType* p_bias_col_up = nullptr;
if(p_bias_grid != nullptr)
{
const long_index_t expert_bias_stride =
static_cast<long_index_t>(problem.N) * (IsInputGemm ? 2 : 1);
const long_index_t base_n = static_cast<long_index_t>(block_n_id) * NPerBlock +
static_cast<long_index_t>(waveId_n) * NXdlPack *
NPerXdl +
threadIdx.x % NPerXdl;
p_bias_col =
p_bias_grid + static_cast<long_index_t>(expert_id) * expert_bias_stride + base_n;
if constexpr(IsInputGemm)
{
p_bias_col_up = p_bias_col + problem.N;
}
}

vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
constexpr index_t n_offset =
n0 * NWave * NXdlPack * NPerXdl + inxdl * NPerXdl;
const float bias =
p_bias_col != nullptr ? type_convert<float>(p_bias_col[n_offset]) : 0.0f;
const float bias_up = p_bias_col_up != nullptr
? type_convert<float>(p_bias_col_up[n_offset])
: 0.0f;
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
Expand Down Expand Up @@ -1553,9 +1583,32 @@ struct GridwiseMoeGemmMXBNS
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
if(p_bias_col != nullptr)
{
gate += bias;
up += bias_up;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation ==
Activation::swiglustep_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
if(p_bias_col != nullptr)
{
gate += bias;
up += bias_up;
}
c_thread_buf_fp32(cidx) =
compute_swiglu_and_mul(gate, up);
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
Expand All @@ -1565,6 +1618,11 @@ struct GridwiseMoeGemmMXBNS
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
if(p_bias_col != nullptr)
{
gate += bias;
up += bias_up;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;

Expand All @@ -1581,13 +1639,16 @@ struct GridwiseMoeGemmMXBNS
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
float out_val = c_thread_buf[cidx];
if(p_bias_col != nullptr)
{
out_val += bias;
}
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m5] *
c_thread_buf_fp32[cidx];
out_val = topk_weights.AsType<float>()[m5] * out_val;
}
c_thread_buf_fp32(cidx) = out_val;
}
});
});
Expand Down Expand Up @@ -1630,6 +1691,7 @@ struct GridwiseMoeGemmMXBNS
const AScaleDataType* p_a_scale_grid,
const BDataType* p_b_grid,
const BScaleDataType* p_b_scale_grid,
const BiasDataType* p_bias_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
Expand Down
Loading
Loading