Fix the segments type in the test

This commit is contained in:
Angelos Katharopoulos
2025-07-07 17:25:19 -07:00
parent 1c589298ec
commit 3336a35512

View File

@@ -1247,7 +1247,7 @@ class TestBlas(mlx_tests.MLXTestCase):
a = mx.ones((10, 1000)) a = mx.ones((10, 1000))
s = mx.random.randint(0, 16, shape=(1000,)) s = mx.random.randint(0, 16, shape=(1000,))
s = mx.zeros(16).at[s].add(1) s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
s = mx.sort(s) s = mx.sort(s)
s = mx.cumsum(s) s = mx.cumsum(s)
s = mx.concatenate([mx.array([0]), s]) s = mx.concatenate([mx.array([0]), s])