mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
@@ -1150,6 +1150,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
|
||||
self.assertTrue(np.array_equal(a_np, out_mlx))
|
||||
|
||||
source = mx.zeros((1, 1, 8, 32))
|
||||
indices = mx.array([0, 2, 4, 5]).reshape((1, 1, 4, 1))
|
||||
update = mx.array(1.0)
|
||||
|
||||
out_mlx = mx.put_along_axis(source, indices, update, axis=-2)
|
||||
out_np = np.array(source)
|
||||
np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2)
|
||||
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
|
||||
|
||||
def test_split(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
splits = mx.split(a, 3)
|
||||
|
||||
Reference in New Issue
Block a user