diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 014e8a9dd..5e096d9c5 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1247,7 +1247,7 @@ class TestBlas(mlx_tests.MLXTestCase): a = mx.ones((10, 1000)) s = mx.random.randint(0, 16, shape=(1000,)) - s = mx.zeros(16).at[s].add(1) + s = mx.zeros(16, dtype=s.dtype).at[s].add(1) s = mx.sort(s) s = mx.cumsum(s) s = mx.concatenate([mx.array([0]), s])