mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix the segments type in the test
This commit is contained in:
@@ -1247,7 +1247,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
a = mx.ones((10, 1000))
|
||||
s = mx.random.randint(0, 16, shape=(1000,))
|
||||
s = mx.zeros(16).at[s].add(1)
|
||||
s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
|
||||
s = mx.sort(s)
|
||||
s = mx.cumsum(s)
|
||||
s = mx.concatenate([mx.array([0]), s])
|
||||
|
||||
Reference in New Issue
Block a user