Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -496,6 +496,16 @@ class TestAutograd(mlx_tests.MLXTestCase):
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
self.assertTrue(mx.allclose(out, expected))
def test_topk_grad(self):
a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)
def fun(x):
return mx.topk(x, 2)
out = mx.vjp(fun, (a,), (mx.ones((2, 2)),))[1][0]
expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32)
self.assertTrue(mx.array_equal(out, expected))
def test_custom_function(self):
# Make a custom function
my_exp = mx.custom_function(mx.exp)

View File

@@ -1075,6 +1075,31 @@ class TestOps(mlx_tests.MLXTestCase):
out_mlx = mx.take_along_axis(a_mlx, mx.reshape(idx_mlx, shape), axis=ax)
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
def test_put_along_axis(self):
for ax in [None, 0, 1, 2]:
a_np = np.arange(16).reshape(2, 2, 4).astype(np.int32)
a_mlx = mx.array(a_np)
if ax == None:
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
values_np = np.random.randint(low=0, high=100, size=(16,))
else:
shape = list(a_np.shape)
shape[ax] = 2
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
values_np = np.random.randint(low=0, high=100, size=shape)
idx_np.astype(np.int32)
values_np.astype(a_np.dtype)
idx_mlx = mx.array(idx_np)
values_mlx = mx.array(values_np)
np.put_along_axis(a_np, idx_np, values_np, axis=ax)
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
self.assertTrue(np.array_equal(a_np, out_mlx))
def test_split(self):
a = mx.array([1, 2, 3])
splits = mx.split(a, 3)