mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 12:09:43 +08:00
Fix flaky addmm tests (#2581)
This commit is contained in:
@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
b = mx.ones((5, 5))
|
b = mx.ones((5, 5))
|
||||||
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
|
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
|
||||||
expected = beta * a + alpha * (b @ a)
|
expected = beta * a + alpha * (b @ a)
|
||||||
self.assertTrue(mx.allclose(expected, out))
|
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
|
||||||
|
|
||||||
# Broadcast c
|
# Broadcast c
|
||||||
a = mx.ones((5, 5))
|
a = mx.ones((5, 5))
|
||||||
@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
c = mx.ones((1, 5))
|
c = mx.ones((1, 5))
|
||||||
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
|
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
|
||||||
expected = beta * c + alpha * (a @ b)
|
expected = beta * c + alpha * (a @ b)
|
||||||
self.assertTrue(mx.allclose(expected, out))
|
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
def make_ref_addmm(alpha, beta):
|
def make_ref_addmm(alpha, beta):
|
||||||
|
|||||||
Reference in New Issue
Block a user