diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 8521d8f80..bbea9ad8e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1478,7 +1478,7 @@ class TestOps(mlx_tests.MLXTestCase): r_mlx = mlxop(y) mx.eval(r_mlx) - self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True)) x = np.random.rand(9, 12, 18) xi = np.random.rand(9, 12, 18)