mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
[CUDA] Set bias as input when using bias epilogue (#2584)
This commit is contained in:
@@ -230,15 +230,20 @@ void CublasGemm::set_out(
|
||||
batch_stride);
|
||||
}
|
||||
|
||||
void CublasGemm::set_bias(void* bias) {
|
||||
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
|
||||
encoder.set_input_array(bias);
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
&epilogue,
|
||||
sizeof(epilogue)));
|
||||
auto* bias_ptr = bias.data<void>();
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
&bias_ptr,
|
||||
sizeof(bias_ptr)));
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
|
@@ -55,7 +55,7 @@ class CublasGemm {
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride);
|
||||
|
||||
void set_bias(void* bias);
|
||||
void set_bias(cu::CommandEncoder& encoder, const array& bias);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
|
@@ -41,7 +41,7 @@ void gemm_and_bias(
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
void* bias = nullptr,
|
||||
const std::optional<array>& bias = std::nullopt,
|
||||
float alpha = 1.0f) {
|
||||
// Check and collapse batch dimensions
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
@@ -93,7 +93,7 @@ void gemm_and_bias(
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
if (bias) {
|
||||
gemm.set_bias(bias);
|
||||
gemm.set_bias(encoder, *bias);
|
||||
}
|
||||
gemm.run(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
||||
@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c.data<void>(),
|
||||
c,
|
||||
alpha_);
|
||||
return;
|
||||
}
|
||||
|
Reference in New Issue
Block a user