improvements to scatter / gather (#1541)

This commit is contained in:
Awni Hannun
2024-10-30 19:30:54 -07:00
committed by GitHub
parent 960e3f0f05
commit 4f72c66911
9 changed files with 195 additions and 248 deletions

View File

@@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
a_mlx = mx.array(a_np)
if ax == None:
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
idx_np = np.random.permutation(a_np.size)
values_np = np.random.randint(low=0, high=100, size=(16,))
else:
shape = list(a_np.shape)
shape[ax] = 2
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
idx_np = np.broadcast_to(idx_np, shape)
values_np = np.random.randint(low=0, high=100, size=shape)
idx_np.astype(np.int32)