mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 23:38:09 +08:00
Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad * zeros for arg reduce
This commit is contained in:
@@ -496,6 +496,16 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_topk_grad(self):
|
||||
a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)
|
||||
|
||||
def fun(x):
|
||||
return mx.topk(x, 2)
|
||||
|
||||
out = mx.vjp(fun, (a,), (mx.ones((2, 2)),))[1][0]
|
||||
expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_custom_function(self):
|
||||
# Make a custom function
|
||||
my_exp = mx.custom_function(mx.exp)
|
||||
|
||||
@@ -1075,6 +1075,31 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_mlx = mx.take_along_axis(a_mlx, mx.reshape(idx_mlx, shape), axis=ax)
|
||||
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
|
||||
|
||||
def test_put_along_axis(self):
|
||||
for ax in [None, 0, 1, 2]:
|
||||
|
||||
a_np = np.arange(16).reshape(2, 2, 4).astype(np.int32)
|
||||
a_mlx = mx.array(a_np)
|
||||
|
||||
if ax == None:
|
||||
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
|
||||
values_np = np.random.randint(low=0, high=100, size=(16,))
|
||||
else:
|
||||
shape = list(a_np.shape)
|
||||
shape[ax] = 2
|
||||
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
|
||||
values_np = np.random.randint(low=0, high=100, size=shape)
|
||||
|
||||
idx_np.astype(np.int32)
|
||||
values_np.astype(a_np.dtype)
|
||||
|
||||
idx_mlx = mx.array(idx_np)
|
||||
values_mlx = mx.array(values_np)
|
||||
|
||||
np.put_along_axis(a_np, idx_np, values_np, axis=ax)
|
||||
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
|
||||
self.assertTrue(np.array_equal(a_np, out_mlx))
|
||||
|
||||
def test_split(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
splits = mx.split(a, 3)
|
||||
|
||||
Reference in New Issue
Block a user