diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index b762aca50..0342173dd 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -167,7 +167,7 @@ class TestReduce(mlx_tests.MLXTestCase): for dtype in dtypes: with self.subTest(dtype=dtype): - x = (mx.random.normal((4, 4))).astype(getattr(mx, dtype)) + x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype)) indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2) for idx in indices: x[*idx] = mx.nan