mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 09:29:26 +08:00
Fix addmm with empty matrices and beta != 1.0 (#2715)
This commit is contained in:
@@ -785,11 +785,46 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(out.item(), 1.0)
|
||||
self.assertEqual(out.shape, ())
|
||||
|
||||
a = mx.zeros(shape=(5, 0))
|
||||
b = mx.zeros(shape=(0, 5))
|
||||
c = mx.random.uniform(shape=(5, 5))
|
||||
out = mx.addmm(c, a, b)
|
||||
self.assertTrue(mx.allclose(out, c))
|
||||
a = mx.ones((2, 0))
|
||||
b = mx.ones((0, 2))
|
||||
c = mx.ones((2, 2))
|
||||
|
||||
test_cases = [
|
||||
(0.0, 1.0),
|
||||
(0.0, 2.0),
|
||||
(0.0, 0.5),
|
||||
(0.0, 0.0),
|
||||
(1.0, 2.0),
|
||||
]
|
||||
|
||||
for alpha, beta in test_cases:
|
||||
with self.subTest(alpha=alpha, beta=beta):
|
||||
result = mx.addmm(c, a, b, alpha=alpha, beta=beta)
|
||||
expected = c * beta # a @ b = 0 for empty matrices
|
||||
self.assertTrue(mx.allclose(result, expected))
|
||||
|
||||
shapes_tests = [
|
||||
((3, 0), (0, 3), (3, 3)),
|
||||
((5, 0), (0, 5), (5, 5)),
|
||||
((1, 0), (0, 10), (1, 10)),
|
||||
((10, 0), (0, 1), (10, 1)),
|
||||
]
|
||||
|
||||
for shape_a, shape_b, shape_c in shapes_tests:
|
||||
with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c):
|
||||
a = mx.ones(shape_a)
|
||||
b = mx.ones(shape_b)
|
||||
c = mx.ones(shape_c)
|
||||
result = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
|
||||
expected = c * 2.0
|
||||
self.assertTrue(mx.allclose(result, expected))
|
||||
|
||||
a = mx.ones((2, 5, 0))
|
||||
b = mx.ones((2, 0, 5))
|
||||
c = mx.ones((2, 5, 5))
|
||||
result = mx.addmm(c, a, b, alpha=0.0, beta=3.0)
|
||||
expected = c * 3.0
|
||||
self.assertTrue(mx.allclose(result, expected))
|
||||
|
||||
def test_block_masked_matmul(self):
|
||||
def ref_block_masked_mm(
|
||||
|
||||
Reference in New Issue
Block a user