From 6a3acf230161e0904bc087c1d4970e444b2f8ed4 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Thu, 11 Sep 2025 15:31:09 +0900
Subject: [PATCH] [CUDA] Set bias as input when using bias epilogue (#2584)

---
 mlx/backend/cuda/gemms/cublas_gemm.cpp | 9 +++++++--
 mlx/backend/cuda/gemms/cublas_gemm.h   | 2 +-
 mlx/backend/cuda/matmul.cpp            | 6 +++---
 python/tests/test_blas.py              | 4 ++--
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 2283351f1..e58653178 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -230,15 +230,20 @@ void CublasGemm::set_out(
       batch_stride);
 }
 
-void CublasGemm::set_bias(void* bias) {
+void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
+  encoder.set_input_array(bias);
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_EPILOGUE,
       &epilogue,
       sizeof(epilogue)));
+  auto* bias_ptr = bias.data<void>();
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+      &bias_ptr,
+      sizeof(bias_ptr)));
 }
 
 void CublasGemm::run(
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index e12c3f5c5..b202818a1 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -55,7 +55,7 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);
 
-  void set_bias(void* bias);
+  void set_bias(cu::CommandEncoder& encoder, const array& bias);
 
   void run(
       cu::CommandEncoder& encoder,
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 744a1bebf..50c8ee629 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -41,7 +41,7 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr,
+    const std::optional<array>& bias = std::nullopt,
     float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
@@ -93,7 +93,7 @@ void gemm_and_bias(
       a_batch_strides.back(),
       b_batch_strides.back());
   if (bias) {
-    gemm.set_bias(bias);
+    gemm.set_bias(encoder, *bias);
   }
   gemm.run(
       encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         out,
         a,
         b,
-        c.data<void>(),
+        c,
         alpha_);
     return;
   }
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index dc9e93699..67289ceef 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
                 b = mx.ones((5, 5))
                 out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
                 expected = beta * a + alpha * (b @ a)
-                self.assertTrue(mx.allclose(expected, out, atol=1e-5))
+                self.assertTrue(mx.allclose(expected, out))
 
                 # Broadcast c
                 a = mx.ones((5, 5))
@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
                 c = mx.ones((1, 5))
                 out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
                 expected = beta * c + alpha * (a @ b)
-                self.assertTrue(mx.allclose(expected, out, atol=1e-5))
+                self.assertTrue(mx.allclose(expected, out))
 
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
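
Note: below is a minimal sketch (not part of the commit) of the path this
patch touches. It assumes an MLX build with CUDA enabled and arrays resident
on the GPU; whether AddMM::eval_gpu actually dispatches to the fused
gemm_and_bias epilogue depends on checks outside this diff, so the snippet
only exercises the path indirectly through mx.addmm. The beta/alpha values
are arbitrary.

    import mlx.core as mx

    # Bias row broadcast over the rows of (a @ b); mirrors the
    # "Broadcast c" case in python/tests/test_blas.py.
    a = mx.ones((5, 5))
    b = mx.ones((5, 5))
    c = mx.ones((1, 5))

    # Before this patch, the cuBLASLt bias epilogue read c's buffer
    # without registering c as an input with the command encoder, so the
    # dependency on whatever kernel produced c was not tracked; the new
    # encoder.set_input_array(bias) call makes that ordering explicit.
    out = mx.addmm(c, a, b, beta=1.0, alpha=0.5)
    mx.eval(out)

    # Expected: True, since addmm computes beta * c + alpha * (a @ b).
    print(mx.allclose(out, 1.0 * c + 0.5 * (a @ b)))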