diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 1a577caaf7..922680110c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4330,6 +4330,10 @@ array addmm(
     c = reshape(c, c_reshape, s);
   }
 
+  if (c.shape() != out_shape) {
+    throw std::invalid_argument(
+        "[addmm] input c must broadcast to the output shape");
+  }
   auto out = array(
       std::move(out_shape),
       out_type,
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index ea4a8752e6..df459eadc7 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -589,6 +589,10 @@ class TestBlas(mlx_tests.MLXTestCase):
         alpha = 0.5
         beta = 2.0
 
+        # c must broadcast to the output shape
+        with self.assertRaises(ValueError):
+            mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
+
         # Regular batched case
         a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
         b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
@@ -745,11 +749,11 @@ class TestBlas(mlx_tests.MLXTestCase):
             mx.eval(c)
             self.assertEqual(c.shape, (0, 0))
 
-        c = mx.array([], dtype=mx.float32)
+        c = mx.array(1.0, dtype=mx.float32)
         a = mx.array([], dtype=mx.float32)
         b = mx.array([], dtype=mx.float32)
-        out = mx.addmm(a, b, c)
-        mx.eval(out)
+        out = mx.addmm(c, a, b)
+        self.assertEqual(out.item(), 1.0)
         self.assertEqual(out.shape, ())
 
         a = mx.zeros(shape=(5, 0))
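
For reference, a minimal reproduction of the new check, mirroring the added test case (a sketch, assuming a build of mlx that includes this patch):

    import mlx.core as mx

    # The matmul of two (2, 2) arrays has shape (2, 2), but c has shape
    # (2, 2, 2), which cannot broadcast to (2, 2). With this patch,
    # addmm raises a ValueError instead of accepting the mismatched c.
    try:
        mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
    except ValueError as err:
        print(err)  # [addmm] input c must broadcast to the output shape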