Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -1075,6 +1075,31 @@ class TestOps(mlx_tests.MLXTestCase):
out_mlx = mx.take_along_axis(a_mlx, mx.reshape(idx_mlx, shape), axis=ax)
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
def test_put_along_axis(self):
for ax in [None, 0, 1, 2]:
a_np = np.arange(16).reshape(2, 2, 4).astype(np.int32)
a_mlx = mx.array(a_np)
if ax == None:
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
values_np = np.random.randint(low=0, high=100, size=(16,))
else:
shape = list(a_np.shape)
shape[ax] = 2
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
values_np = np.random.randint(low=0, high=100, size=shape)
idx_np.astype(np.int32)
values_np.astype(a_np.dtype)
idx_mlx = mx.array(idx_np)
values_mlx = mx.array(values_np)
np.put_along_axis(a_np, idx_np, values_np, axis=ax)
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
self.assertTrue(np.array_equal(a_np, out_mlx))
def test_split(self):
a = mx.array([1, 2, 3])
splits = mx.split(a, 3)