scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
This commit is contained in:
Awni Hannun
2025-01-31 20:48:08 -08:00
committed by GitHub
parent c6fc07f1f4
commit b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -1150,6 +1150,15 @@ class TestOps(mlx_tests.MLXTestCase):
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
self.assertTrue(np.array_equal(a_np, out_mlx))
source = mx.zeros((1, 1, 8, 32))
indices = mx.array([0, 2, 4, 5]).reshape((1, 1, 4, 1))
update = mx.array(1.0)
out_mlx = mx.put_along_axis(source, indices, update, axis=-2)
out_np = np.array(source)
np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2)
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
def test_split(self):
a = mx.array([1, 2, 3])
splits = mx.split(a, 3)