mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 01:48:12 +08:00
cpu fallback
This commit is contained in:
@@ -168,7 +168,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
if dtype == mx.float16:
|
||||
rtol = 1e-2
|
||||
|
||||
# np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
|
||||
self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
|
||||
self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user