Fix addmm with empty matrices and beta != 1.0 (#2715)

Harsh Sutaria
2025-11-03 17:16:15 -05:00
committed by GitHub
parent 1ff2b713b6
commit 50fa315d18
3 changed files with 105 additions and 16 deletions
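A minimal sketch (not part of the diff) of the behavior this commit addresses, written against the mlx Python API; the shapes are illustrative. With an empty inner dimension (K == 0), a @ b contributes nothing, so addmm should reduce to beta * c; the old K == 0 path copied c into the output without applying beta.

# Illustrative only: expected addmm semantics when K == 0.
import mlx.core as mx

a = mx.ones((4, 0))  # K == 0, so a @ b is an all-zeros (4, 4) matrix
b = mx.ones((0, 4))
c = mx.ones((4, 4))

out = mx.addmm(c, a, b, alpha=1.0, beta=2.0)
# alpha * (a @ b) contributes nothing, so the result should equal beta * c.
assert mx.allclose(out, 2.0 * c)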

View File

@@ -2,6 +2,8 @@
 #include <cstring>
 #include "mlx/array.h"
+#include "mlx/backend/cpu/binary.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,58 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }
+  // Handle empty matrix case (K=0)
+  if (inputs[0].shape(-1) == 0) {
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      CopyType ctype = c.data_size() == 1
+          ? CopyType::Scalar
+          : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+      copy_cpu(c, out, ctype, stream());
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      auto bopt = get_binary_op_type(c, beta_scalar);
+      set_binary_op_output_data(c, beta_scalar, out, bopt);
+      auto& encoder = cpu::get_command_encoder(stream());
+      encoder.set_input_array(c);
+      encoder.set_input_array(beta_scalar);
+      encoder.set_output_array(out);
+      encoder.dispatch([c = array::unsafe_weak_copy(c),
+                        beta_scalar = array::unsafe_weak_copy(beta_scalar),
+                        out = array::unsafe_weak_copy(out),
+                        bopt]() mutable {
+        switch (out.dtype()) {
+          case float16:
+            binary_op<float16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float32:
+            binary_op<float, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float64:
+            binary_op<double, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case bfloat16:
+            binary_op<bfloat16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case complex64:
+            binary_op<complex64_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          default:
+            throw std::runtime_error(
+                "[AddMM::eval_cpu] Unsupported dtype for beta scaling");
+        }
+      });
+      encoder.add_temporary(std::move(beta_scalar));
+    }
+    return;
+  }
   // Fill output with C
   auto& c = inputs[2];
   CopyType ctype = c.data_size() == 1
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
   copy_cpu(c, out, ctype, stream());
-  if (inputs[0].shape(-1) == 0) {
-    return;
-  }
   matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
 }
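As a rough Python-level sketch (an assumption drawn from the diff, not a statement of the implementation), the new CPU branch above keeps c's dtype when scaling by beta, because the scalar is constructed as array(beta_, c.dtype()) before the elementwise multiply:

# Sketch only: the dtype of the K == 0 result should follow c.
import mlx.core as mx

a = mx.ones((3, 0), dtype=mx.float16)
b = mx.ones((0, 3), dtype=mx.float16)
c = mx.ones((3, 3), dtype=mx.float16)

out = mx.addmm(c, a, b, beta=0.5)
assert out.dtype == mx.float16
assert mx.allclose(out, 0.5 * c)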

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
@@ -925,19 +926,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
-  // Copy c into out and return
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  // Handle empty matrix case (K=0)
   if (inputs[0].shape(-1) == 0) {
-    copy_gpu(
-        inputs[2],
-        out,
-        inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
-        stream());
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      copy_gpu(
+          c,
+          out,
+          c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+          s);
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
+      d.add_temporary(std::move(beta_scalar), s.index);
+    }
     return;
   }
   out.set_data(allocator::malloc(out.nbytes()));
-  auto& s = stream();
-  auto& d = metal::device(s.device);
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
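A companion sketch for the Metal path (assumptions: a GPU device is available; mx.gpu, mx.stream, and mx.eval are existing mlx Python APIs). With the early-out branch above, the GPU result should satisfy the same beta * c expectation:

# Sketch only: same K == 0 semantics when evaluated on the GPU stream.
import mlx.core as mx

with mx.stream(mx.gpu):
    c = mx.ones((4, 4))
    out = mx.addmm(c, mx.ones((4, 0)), mx.ones((0, 4)), beta=3.0)
    mx.eval(out)

assert mx.allclose(out, 3.0 * c)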

View File

@@ -785,11 +785,46 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(out.item(), 1.0)
         self.assertEqual(out.shape, ())
-        a = mx.zeros(shape=(5, 0))
-        b = mx.zeros(shape=(0, 5))
-        c = mx.random.uniform(shape=(5, 5))
-        out = mx.addmm(c, a, b)
-        self.assertTrue(mx.allclose(out, c))
+        a = mx.ones((2, 0))
+        b = mx.ones((0, 2))
+        c = mx.ones((2, 2))
+        test_cases = [
+            (0.0, 1.0),
+            (0.0, 2.0),
+            (0.0, 0.5),
+            (0.0, 0.0),
+            (1.0, 2.0),
+        ]
+        for alpha, beta in test_cases:
+            with self.subTest(alpha=alpha, beta=beta):
+                result = mx.addmm(c, a, b, alpha=alpha, beta=beta)
+                expected = c * beta  # a @ b = 0 for empty matrices
+                self.assertTrue(mx.allclose(result, expected))
+        shapes_tests = [
+            ((3, 0), (0, 3), (3, 3)),
+            ((5, 0), (0, 5), (5, 5)),
+            ((1, 0), (0, 10), (1, 10)),
+            ((10, 0), (0, 1), (10, 1)),
+        ]
+        for shape_a, shape_b, shape_c in shapes_tests:
+            with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c):
+                a = mx.ones(shape_a)
+                b = mx.ones(shape_b)
+                c = mx.ones(shape_c)
+                result = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
+                expected = c * 2.0
+                self.assertTrue(mx.allclose(result, expected))
+        a = mx.ones((2, 5, 0))
+        b = mx.ones((2, 0, 5))
+        c = mx.ones((2, 5, 5))
+        result = mx.addmm(c, a, b, alpha=0.0, beta=3.0)
+        expected = c * 3.0
+        self.assertTrue(mx.allclose(result, expected))
 
     def test_block_masked_matmul(self):
         def ref_block_masked_mm(