Fix addmm with empty matrices and beta != 1.0 (#2715)

This commit is contained in:
Harsh Sutaria
2025-11-03 17:16:15 -05:00
committed by GitHub
parent 1ff2b713b6
commit 50fa315d18
3 changed files with 105 additions and 16 deletions

View File

@@ -785,11 +785,46 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertEqual(out.item(), 1.0)
self.assertEqual(out.shape, ())
a = mx.zeros(shape=(5, 0))
b = mx.zeros(shape=(0, 5))
c = mx.random.uniform(shape=(5, 5))
out = mx.addmm(c, a, b)
self.assertTrue(mx.allclose(out, c))
a = mx.ones((2, 0))
b = mx.ones((0, 2))
c = mx.ones((2, 2))
test_cases = [
(0.0, 1.0),
(0.0, 2.0),
(0.0, 0.5),
(0.0, 0.0),
(1.0, 2.0),
]
for alpha, beta in test_cases:
with self.subTest(alpha=alpha, beta=beta):
result = mx.addmm(c, a, b, alpha=alpha, beta=beta)
expected = c * beta # a @ b = 0 for empty matrices
self.assertTrue(mx.allclose(result, expected))
shapes_tests = [
((3, 0), (0, 3), (3, 3)),
((5, 0), (0, 5), (5, 5)),
((1, 0), (0, 10), (1, 10)),
((10, 0), (0, 1), (10, 1)),
]
for shape_a, shape_b, shape_c in shapes_tests:
with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c):
a = mx.ones(shape_a)
b = mx.ones(shape_b)
c = mx.ones(shape_c)
result = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
expected = c * 2.0
self.assertTrue(mx.allclose(result, expected))
a = mx.ones((2, 5, 0))
b = mx.ones((2, 0, 5))
c = mx.ones((2, 5, 5))
result = mx.addmm(c, a, b, alpha=0.0, beta=3.0)
expected = c * 3.0
self.assertTrue(mx.allclose(result, expected))
def test_block_masked_matmul(self):
def ref_block_masked_mm(