diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index 8ae99ab2d..b944aacc0 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[AddMM::eval_cpu] Currently only supports float32.");
   }
+  if (out.size() == 0) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    return;
+  }
 
   // Fill output with C
   auto& c = inputs[2];
@@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
   copy(c, out, ctype, stream());
-
+  if (inputs[0].shape(-1) == 0) {
+    return;
+  }
   matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
 }
 
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 71221f8d9..e0ff44200 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -716,6 +716,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[matmul] Does not yet support non-floating point types.");
   }
+
+  // Return 0s if either input is empty
+  if (out.size() == 0) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    return;
+  }
+
+  // Copy c into out and return
+  if (inputs[0].shape(-1) == 0) {
+    copy_gpu(
+        inputs[2],
+        out,
+        inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
+        stream());
+    return;
+  }
+
   out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index e8c260425..922680110 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -472,6 +472,10 @@ array hadamard_transform(
     const array& a,
     std::optional<float> scale_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
+  if (a.size() == 0) {
+    throw std::invalid_argument(
+        "[hadamard_transform] Does not support empty arrays.");
+  }
   // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
   int n = a.ndim() > 0 ? a.shape(-1) : 1;
   float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
@@ -4326,6 +4330,10 @@ array addmm(
     c = reshape(c, c_reshape, s);
   }
 
+  if (c.shape() != out_shape) {
+    throw std::invalid_argument(
+        "[addmm] input c must broadcast to the output shape");
+  }
 
   auto out = array(
       std::move(out_shape),
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 6fca4885b..df459eadc 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -589,6 +589,10 @@ class TestBlas(mlx_tests.MLXTestCase):
         alpha = 0.5
         beta = 2.0
 
+        # c must broadcast to the output shape
+        with self.assertRaises(ValueError):
+            mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
+
         # Regular batched case
         a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
         b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
@@ -745,6 +749,19 @@ class TestBlas(mlx_tests.MLXTestCase):
         mx.eval(c)
         self.assertEqual(c.shape, (0, 0))
 
+        c = mx.array(1.0, dtype=mx.float32)
+        a = mx.array([], dtype=mx.float32)
+        b = mx.array([], dtype=mx.float32)
+        out = mx.addmm(c, a, b)
+        self.assertEqual(out.item(), 1.0)
+        self.assertEqual(out.shape, ())
+
+        a = mx.zeros(shape=(5, 0))
+        b = mx.zeros(shape=(0, 5))
+        c = mx.random.uniform(shape=(5, 5))
+        out = mx.addmm(c, a, b)
+        self.assertTrue(mx.allclose(out, c))
+
     def test_block_masked_matmul(self):
         def ref_block_masked_mm(
             a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index d9e143d82..0921de788 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2830,6 +2830,9 @@ class TestOps(mlx_tests.MLXTestCase):
         return H
 
     def test_hadamard(self):
+        with self.assertRaises(ValueError):
+            mx.hadamard_transform(mx.array([]))
+
         h28_str = """
 +------++----++-+--+-+--++--
 -+-----+++-----+-+--+-+--++-
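
Illustrative only, not part of the patch: a minimal Python sketch of the behavior the new tests exercise, assuming an MLX build that includes these changes.

    import mlx.core as mx

    # Empty inner dimension: the matmul contributes zeros, so addmm returns c.
    a = mx.zeros(shape=(5, 0))
    b = mx.zeros(shape=(0, 5))
    c = mx.random.uniform(shape=(5, 5))
    assert mx.allclose(mx.addmm(c, a, b), c).item()

    # A c that does not broadcast to the output shape now raises ValueError.
    try:
        mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
    except ValueError:
        pass

    # hadamard_transform rejects empty inputs with ValueError.
    try:
        mx.hadamard_transform(mx.array([]))
    except ValueError:
        pass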