diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index b346e84db..0998c527c 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -2,6 +2,8 @@
 
 #include <cstring>
 #include "mlx/array.h"
+#include "mlx/backend/cpu/binary.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,58 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
+  // Handle empty matrix case (K=0)
+  if (inputs[0].shape(-1) == 0) {
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      CopyType ctype = c.data_size() == 1
+          ? CopyType::Scalar
+          : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+      copy_cpu(c, out, ctype, stream());
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      auto bopt = get_binary_op_type(c, beta_scalar);
+      set_binary_op_output_data(c, beta_scalar, out, bopt);
+      auto& encoder = cpu::get_command_encoder(stream());
+      encoder.set_input_array(c);
+      encoder.set_input_array(beta_scalar);
+      encoder.set_output_array(out);
+      encoder.dispatch([c = array::unsafe_weak_copy(c),
+                        beta_scalar = array::unsafe_weak_copy(beta_scalar),
+                        out = array::unsafe_weak_copy(out),
+                        bopt]() mutable {
+        switch (out.dtype()) {
+          case float16:
+            binary_op<float16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float32:
+            binary_op<float, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float64:
+            binary_op<double, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case bfloat16:
+            binary_op<bfloat16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case complex64:
+            binary_op<complex64_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          default:
+            throw std::runtime_error(
+                "[AddMM::eval_cpu] Unsupported dtype for beta scaling");
+        }
+      });
+      encoder.add_temporary(std::move(beta_scalar));
+    }
+    return;
+  }
+
   // Fill output with C
   auto& c = inputs[2];
   CopyType ctype = c.data_size() == 1
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
   copy_cpu(c, out, ctype, stream());
 
-  if (inputs[0].shape(-1) == 0) {
-    return;
-  }
   matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
 }
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index d6bee651d..abc45575a 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -8,6 +8,7 @@
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/metal/binary.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -925,19 +926,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  // Copy c into out and return
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  // Handle empty matrix case (K=0)
   if (inputs[0].shape(-1) == 0) {
-    copy_gpu(
-        inputs[2],
-        out,
-        inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
-        stream());
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      copy_gpu(
+          c,
+          out,
+          c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+          s);
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
+      d.add_temporary(std::move(beta_scalar), s.index);
+    }
     return;
   }
 
   out.set_data(allocator::malloc(out.nbytes()));
-  auto& s = stream();
-  auto& d = metal::device(s.device);
 
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 0cc6e621a..8e97e1f49 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -785,11 +785,46 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(out.item(), 1.0)
         self.assertEqual(out.shape, ())
 
-        a = mx.zeros(shape=(5, 0))
-        b = mx.zeros(shape=(0, 5))
-        c = mx.random.uniform(shape=(5, 5))
-        out = mx.addmm(c, a, b)
-        self.assertTrue(mx.allclose(out, c))
+        a = mx.ones((2, 0))
+        b = mx.ones((0, 2))
+        c = mx.ones((2, 2))
+
+        test_cases = [
+            (0.0, 1.0),
+            (0.0, 2.0),
+            (0.0, 0.5),
+            (0.0, 0.0),
+            (1.0, 2.0),
+        ]
+
+        for alpha, beta in test_cases:
+            with self.subTest(alpha=alpha, beta=beta):
+                result = mx.addmm(c, a, b, alpha=alpha, beta=beta)
+                expected = c * beta  # a @ b = 0 for empty matrices
+                self.assertTrue(mx.allclose(result, expected))
+
+        shapes_tests = [
+            ((3, 0), (0, 3), (3, 3)),
+            ((5, 0), (0, 5), (5, 5)),
+            ((1, 0), (0, 10), (1, 10)),
+            ((10, 0), (0, 1), (10, 1)),
+        ]
+
+        for shape_a, shape_b, shape_c in shapes_tests:
+            with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c):
+                a = mx.ones(shape_a)
+                b = mx.ones(shape_b)
+                c = mx.ones(shape_c)
+                result = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
+                expected = c * 2.0
+                self.assertTrue(mx.allclose(result, expected))
+
+        a = mx.ones((2, 5, 0))
+        b = mx.ones((2, 0, 5))
+        c = mx.ones((2, 5, 5))
+        result = mx.addmm(c, a, b, alpha=0.0, beta=3.0)
+        expected = c * 3.0
+        self.assertTrue(mx.allclose(result, expected))
 
     def test_block_masked_matmul(self):
         def ref_block_masked_mm(
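
For a quick manual check of the behavior this patch targets, a minimal sketch (using only the mx.addmm, mx.ones, and mx.allclose calls already exercised in the tests above; the shapes and scalars here are illustrative, not taken from the patch): with an inner dimension K = 0 the product a @ b contributes nothing, so addmm should reduce to beta * c.

    import mlx.core as mx

    # K = 0: the matrix product is empty, so only the beta * c term remains.
    a = mx.ones((4, 0))
    b = mx.ones((0, 4))
    c = mx.ones((4, 4))

    out = mx.addmm(c, a, b, alpha=1.0, beta=2.0)
    assert mx.allclose(out, c * 2.0)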