Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -496,6 +496,16 @@ class TestAutograd(mlx_tests.MLXTestCase):
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
self.assertTrue(mx.allclose(out, expected))
def test_topk_grad(self):
a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)
def fun(x):
return mx.topk(x, 2)
out = mx.vjp(fun, (a,), (mx.ones((2, 2)),))[1][0]
expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32)
self.assertTrue(mx.array_equal(out, expected))
def test_custom_function(self):
# Make a custom function
my_exp = mx.custom_function(mx.exp)