diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 89d2a9e7b..2283351f1 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -248,11 +248,19 @@ void CublasGemm::run(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   if (batch_count / batch_shape.back() > 1) {
     run_batched(
-        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+        encoder,
+        out,
+        a,
+        b,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        alpha);
     return;
   }
 
@@ -260,7 +268,13 @@ void CublasGemm::run(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
 
-  execute(encoder, out.data(), a.data(), b.data(), nullptr);
+  execute(
+      encoder,
+      out.data(),
+      a.data(),
+      b.data(),
+      nullptr,
+      alpha);
 }
 
 void CublasGemm::run(
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index 857910e7f..e12c3f5c5 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -64,7 +64,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha = 1.0f);
 
   void run(
       cu::CommandEncoder& encoder,
@@ -87,7 +88,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha);
 
   void run_batched(
       cu::CommandEncoder& encoder,
diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
index 56c731587..70df21fda 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
         out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data() + a.itemsize() * a_it.loc,
         b.data() + b.itemsize() * b_it.loc,
-        nullptr);
+        nullptr,
+        alpha);
     a_it.step();
     b_it.step();
   }
diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
index 570b79463..41ab9c8bd 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
+++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
       reinterpret_cast(out_pointers),
       reinterpret_cast(a_pointers),
       reinterpret_cast(b_pointers),
-      nullptr);
+      nullptr,
+      alpha);
 }
 
 void CublasGemm::run_batched(
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 66cd025df..744a1bebf 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -41,7 +41,8 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr) {
+    void* bias = nullptr,
+    float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
 
@@ -94,7 +95,8 @@ void gemm_and_bias(
   if (bias) {
     gemm.set_bias(bias);
   }
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm.run(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
 }
 
 } // namespace
@@ -169,7 +171,8 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) {
         out,
         a,
         b,
-        c.data());
+        c.data(),
+        alpha_);
     return;
   }
 
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index bc675535a..67289ceef 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
         np.random.seed(0)
         # Batched matmul
         alpha = 0.5
-        beta = 2.0
+        for beta in (1.0, 2.0):
+            # c must broadcast to the output shape
+            with self.assertRaises(ValueError):
+                mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
 
-        # c must broadcast to the output shape
-        with self.assertRaises(ValueError):
-            mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
+            # Regular batched case
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
 
-        # Regular batched case
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
 
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
+            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
 
-        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
 
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+            # Batched and transposed matmul
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_mlx = mx.array(b_npy)
 
-        # Batched and transposed matmul
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_mlx = mx.array(b_npy)
+            for c_shape in ((1,), (32, 1, 128), (1, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
 
-        for c_shape in ((1,), (32, 1, 128), (1, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
+                b_np_t = np.transpose(b_npy, (0, 2, 1))
+                b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
 
-            b_np_t = np.transpose(b_npy, (0, 2, 1))
-            b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
+                d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
 
-            d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+            # Batched matmul with simple broadcast
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
 
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-        # Batched matmul with simple broadcast
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
 
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
+            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
 
-        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
 
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
 
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-        # Matmul with vector
-        a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
+            for c_shape in ((1,), (128,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
 
-        for c_shape in ((1,), (128,), (32, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
 
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
 
-        # Matmul with vector
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
+            for c_shape in ((1,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
 
-        for c_shape in ((1,), (32, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
 
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+            # Split K specializtion
+            a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
 
-        # Split K specializtion
-        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
 
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
+            for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
 
-        for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
 
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+            # Transposed c
+            a = mx.ones((10, 5)).T
+            b = mx.ones((5, 5))
+            out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
+            expected = beta * a + alpha * (b @ a)
+            self.assertTrue(mx.allclose(expected, out))
 
-        # Transposed c
-        a = mx.ones((10, 5)).T
-        b = mx.ones((5, 5))
-        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
-        expected = 1.5 * a + 0.5 * (b @ a)
-        self.assertTrue(mx.allclose(expected, out))
-
-        # Broadcast c
-        a = mx.ones((5, 5))
-        b = mx.ones((5, 5))
-        c = mx.ones((1, 5))
-        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
-        expected = 1.5 * c + 0.5 * (a @ b)
-        self.assertTrue(mx.allclose(expected, out))
+            # Broadcast c
+            a = mx.ones((5, 5))
+            b = mx.ones((5, 5))
+            c = mx.ones((1, 5))
+            out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
+            expected = beta * c + alpha * (a @ b)
+            self.assertTrue(mx.allclose(expected, out))
 
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
         shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
 
         alpha = 2.0
-        beta = 0.5
+        for beta in (1.0, 0.5):
+            f_test = make_addmm(alpha, beta)
+            f_ref = make_ref_addmm(alpha, beta)
 
-        f_test = make_addmm(alpha, beta)
-        f_ref = make_ref_addmm(alpha, beta)
+            for B, M, N, K in shapes:
+                cotan = mx.ones((B, M, N))
+                c = mx.random.normal((B, M, N))
+                a = mx.random.normal((B, M, K))
+                b = mx.random.normal((B, K, N))
 
-        for B, M, N, K in shapes:
-            cotan = mx.ones((B, M, N))
-            c = mx.random.normal((B, M, N))
-            a = mx.random.normal((B, M, K))
-            b = mx.random.normal((B, K, N))
+                out_ref, dout_ref = mx.vjp(
+                    f_ref,
+                    [c, a, b],
+                    [cotan],
+                )
+                out_test, dout_test = mx.vjp(
+                    f_test,
+                    [c, a, b],
+                    [cotan],
+                )
 
-            out_ref, dout_ref = mx.vjp(
-                f_ref,
-                [c, a, b],
-                [cotan],
-            )
-            out_test, dout_test = mx.vjp(
-                f_test,
-                [c, a, b],
-                [cotan],
-            )
+                self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
 
-            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
-
-            for r, t in zip(dout_ref, dout_test):
-                self.assertEqual(r.shape, t.shape)
-                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+                for r, t in zip(dout_ref, dout_test):
+                    self.assertEqual(r.shape, t.shape)
+                    self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
 
     def test_empty_matmul(self):
         a = mx.array([[], []]).T