Mirror of https://github.com/ml-explore/mlx.git
	[CUDA] Fix alpha not respected when using bias epilogue (#2578)
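mx.addmm computes beta * c + alpha * (a @ b). On the CUDA backend, when the bias-add epilogue is used (apparently the beta == 1.0 case, which is exactly what the updated test below starts covering), alpha was previously dropped and effectively treated as 1. A minimal sketch of the affected call pattern, with illustrative shapes not taken from the original report:

```python
import mlx.core as mx
import numpy as np

# addmm semantics: out = beta * c + alpha * (a @ b)
a = mx.array(np.random.normal(0.0, 1.0 / 128, (64, 256)).astype(np.float32))
b = mx.array(np.random.normal(0.0, 1.0 / 128, (256, 32)).astype(np.float32))
c = mx.ones((64, 32))

# beta == 1.0 is the bias-epilogue case; before this fix the CUDA path
# ignored alpha here, effectively computing c + (a @ b).
out = mx.addmm(c, a, b, alpha=0.5, beta=1.0)
ref = c + 0.5 * (a @ b)
assert mx.allclose(out, ref, atol=1e-5).item()
```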
@@ -248,11 +248,19 @@ void CublasGemm::run(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   if (batch_count / batch_shape.back() > 1) {
     run_batched(
-        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+        encoder,
+        out,
+        a,
+        b,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        alpha);
     return;
   }

@@ -260,7 +268,13 @@ void CublasGemm::run(
   encoder.set_input_array(b);
   encoder.set_output_array(out);

-  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      nullptr,
+      alpha);
 }

 void CublasGemm::run(

@@ -64,7 +64,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha = 1.0f);

   void run(
       cu::CommandEncoder& encoder,
@@ -87,7 +88,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha);

   void run_batched(
       cu::CommandEncoder& encoder,

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data<int8_t>() + a.itemsize() * a_it.loc,
         b.data<int8_t>() + b.itemsize() * b_it.loc,
-        nullptr);
+        nullptr,
+        alpha);
     a_it.step();
     b_it.step();
   }

@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
       reinterpret_cast<void*>(b_pointers),
-      nullptr);
+      nullptr,
+      alpha);
 }

 void CublasGemm::run_batched(

@@ -41,7 +41,8 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr) {
+    void* bias = nullptr,
+    float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);

@@ -94,7 +95,8 @@ void gemm_and_bias(
   if (bias) {
     gemm.set_bias(bias);
   }
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm.run(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
 }

 } // namespace
@@ -169,7 +171,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         out,
         a,
         b,
-        c.data<void>());
+        c.data<void>(),
+        alpha_);
     return;
   }

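With alpha now threaded from AddMM::eval_gpu through gemm_and_bias into CublasGemm, the observable behaviour can be checked directly from Python. A standalone sketch along the lines of the updated test below (shapes chosen to mirror it; the standalone form is not part of the commit). In the test diff that follows, the existing body of test_addmm is re-indented into a new `for beta in (1.0, 2.0):` loop, so the old (less indented) and new (more indented) lines appear interleaved.

```python
import mlx.core as mx
import numpy as np

alpha = 0.5
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
c_npy = np.ones((1, 16), np.float32)

# beta == 1.0 exercises the bias-epilogue case the fix targets;
# beta == 2.0 is the previously covered path.
for beta in (1.0, 2.0):
    d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
    d_mlx = mx.addmm(mx.array(c_npy), mx.array(a_npy), mx.array(b_npy), alpha, beta)
    assert np.allclose(d_mlx, d_npy, atol=1e-5)
```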
@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
        np.random.seed(0)
        # Batched matmul
        alpha = 0.5
        beta = 2.0
        for beta in (1.0, 2.0):
            # c must broadcast to the output shape
            with self.assertRaises(ValueError):
                mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))

        # c must broadcast to the output shape
        with self.assertRaises(ValueError):
            mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
            # Regular batched case
            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)

        # Regular batched case
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
            a_mlx = mx.array(a_npy)
            b_mlx = mx.array(b_npy)

        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
                c_npy = np.ones(c_shape).astype(np.float32)
                c_mlx = mx.array(c_npy)

        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)
                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
            # Batched and transposed matmul
            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
            b_mlx = mx.array(b_npy)

        # Batched and transposed matmul
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_mlx = mx.array(b_npy)
            for c_shape in ((1,), (32, 1, 128), (1, 128)):
                c_npy = np.ones(c_shape).astype(np.float32)
                c_mlx = mx.array(c_npy)

        for c_shape in ((1,), (32, 1, 128), (1, 128)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)
                b_np_t = np.transpose(b_npy, (0, 2, 1))
                b_mx_t = mx.transpose(b_mlx, (0, 2, 1))

            b_np_t = np.transpose(b_npy, (0, 2, 1))
            b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
                d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
                d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)

            d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
            # Batched matmul with simple broadcast
            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
            b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
        # Batched matmul with simple broadcast
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
            a_mlx = mx.array(a_npy)
            b_mlx = mx.array(b_npy)

        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
                c_npy = np.ones(c_shape).astype(np.float32)
                c_mlx = mx.array(c_npy)

        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)
                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
            # Matmul with vector
            a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
            a_mlx = mx.array(a_npy)
            b_mlx = mx.array(b_npy)

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
        # Matmul with vector
        a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
            for c_shape in ((1,), (128,), (32, 128)):
                c_npy = np.ones(c_shape).astype(np.float32)
                c_mlx = mx.array(c_npy)

        for c_shape in ((1,), (128,), (32, 128)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)
                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
            # Matmul with vector
            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
            b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
            a_mlx = mx.array(a_npy)
            b_mlx = mx.array(b_npy)

        # Matmul with vector
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
            for c_shape in ((1,), (32, 128)):
                c_npy = np.ones(c_shape).astype(np.float32)
                c_mlx = mx.array(c_npy)

        for c_shape in ((1,), (32, 128)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)
                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
            # Split K specializtion
            a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
            b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)

        # Split K specializtion
        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
            a_mlx = mx.array(a_npy)
            b_mlx = mx.array(b_npy)

        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
            for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
                c_npy = np.ones(c_shape).astype(np.float32)
                c_mlx = mx.array(c_npy)

        for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
            c_npy = np.ones(c_shape).astype(np.float32)
            c_mlx = mx.array(c_npy)
                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
            # Transposed c
            a = mx.ones((10, 5)).T
            b = mx.ones((5, 5))
            out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
            expected = beta * a + alpha * (b @ a)
            self.assertTrue(mx.allclose(expected, out))

        # Transposed c
        a = mx.ones((10, 5)).T
        b = mx.ones((5, 5))
        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
        expected = 1.5 * a + 0.5 * (b @ a)
        self.assertTrue(mx.allclose(expected, out))

        # Broadcast c
        a = mx.ones((5, 5))
        b = mx.ones((5, 5))
        c = mx.ones((1, 5))
        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
        expected = 1.5 * c + 0.5 * (a @ b)
        self.assertTrue(mx.allclose(expected, out))
            # Broadcast c
            a = mx.ones((5, 5))
            b = mx.ones((5, 5))
            c = mx.ones((1, 5))
            out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
            expected = beta * c + alpha * (a @ b)
            self.assertTrue(mx.allclose(expected, out))

    def test_addmm_grad(self):
        def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
        shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))

        alpha = 2.0
        beta = 0.5
        for beta in (1.0, 0.5):
            f_test = make_addmm(alpha, beta)
            f_ref = make_ref_addmm(alpha, beta)

        f_test = make_addmm(alpha, beta)
        f_ref = make_ref_addmm(alpha, beta)
            for B, M, N, K in shapes:
                cotan = mx.ones((B, M, N))
                c = mx.random.normal((B, M, N))
                a = mx.random.normal((B, M, K))
                b = mx.random.normal((B, K, N))

        for B, M, N, K in shapes:
            cotan = mx.ones((B, M, N))
            c = mx.random.normal((B, M, N))
            a = mx.random.normal((B, M, K))
            b = mx.random.normal((B, K, N))
                out_ref, dout_ref = mx.vjp(
                    f_ref,
                    [c, a, b],
                    [cotan],
                )
                out_test, dout_test = mx.vjp(
                    f_test,
                    [c, a, b],
                    [cotan],
                )

            out_ref, dout_ref = mx.vjp(
                f_ref,
                [c, a, b],
                [cotan],
            )
            out_test, dout_test = mx.vjp(
                f_test,
                [c, a, b],
                [cotan],
            )
                self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())

            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())

            for r, t in zip(dout_ref, dout_test):
                self.assertEqual(r.shape, t.shape)
                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
                for r, t in zip(dout_ref, dout_test):
                    self.assertEqual(r.shape, t.shape)
                    self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

    def test_empty_matmul(self):
        a = mx.array([[], []]).T

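The gradient test gets the same treatment, since vjps of addmm route through the same GEMM path. A condensed, standalone version of that check with small, hypothetical shapes (not the shapes used in the test itself):

```python
import mlx.core as mx

alpha, beta = 2.0, 1.0  # beta == 1.0 is the newly covered case

f_test = lambda c, a, b: mx.addmm(c, a, b, alpha, beta)
f_ref = lambda c, a, b: beta * c + alpha * (a @ b)

B, M, N, K = 2, 8, 6, 4
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))

out_ref, dout_ref = mx.vjp(f_ref, [c, a, b], [cotan])
out_test, dout_test = mx.vjp(f_test, [c, a, b], [cotan])

assert mx.allclose(out_ref[0], out_test[0], atol=1e-4).item()
for r, t in zip(dout_ref, dout_test):
    assert mx.allclose(r, t, atol=1e-4).item()
```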
Author: Cheng