mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 20:18:15 +08:00
[CUDA] Set bias as input when using bias epilogue (#2584)
This commit is contained in:
@@ -230,15 +230,20 @@ void CublasGemm::set_out(
|
||||
batch_stride);
|
||||
}
|
||||
|
||||
void CublasGemm::set_bias(void* bias) {
|
||||
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
|
||||
encoder.set_input_array(bias);
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
&epilogue,
|
||||
sizeof(epilogue)));
|
||||
auto* bias_ptr = bias.data<void>();
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
&bias_ptr,
|
||||
sizeof(bias_ptr)));
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
|
@@ -55,7 +55,7 @@ class CublasGemm {
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride);
|
||||
|
||||
void set_bias(void* bias);
|
||||
void set_bias(cu::CommandEncoder& encoder, const array& bias);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
|
@@ -41,7 +41,7 @@ void gemm_and_bias(
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
void* bias = nullptr,
|
||||
const std::optional<array>& bias = std::nullopt,
|
||||
float alpha = 1.0f) {
|
||||
// Check and collapse batch dimensions
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
@@ -93,7 +93,7 @@ void gemm_and_bias(
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
if (bias) {
|
||||
gemm.set_bias(bias);
|
||||
gemm.set_bias(encoder, *bias);
|
||||
}
|
||||
gemm.run(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
||||
@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c.data<void>(),
|
||||
c,
|
||||
alpha_);
|
||||
return;
|
||||
}
|
||||
|
@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
b = mx.ones((5, 5))
|
||||
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
|
||||
expected = beta * a + alpha * (b @ a)
|
||||
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Broadcast c
|
||||
a = mx.ones((5, 5))
|
||||
@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
c = mx.ones((1, 5))
|
||||
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
|
||||
expected = beta * c + alpha * (a @ b)
|
||||
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_addmm_grad(self):
|
||||
def make_ref_addmm(alpha, beta):
|
||||
|
Reference in New Issue
Block a user