mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad * zeros for arg reduce
This commit is contained in:
@@ -496,6 +496,16 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_topk_grad(self):
|
||||
a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)
|
||||
|
||||
def fun(x):
|
||||
return mx.topk(x, 2)
|
||||
|
||||
out = mx.vjp(fun, (a,), (mx.ones((2, 2)),))[1][0]
|
||||
expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_custom_function(self):
|
||||
# Make a custom function
|
||||
my_exp = mx.custom_function(mx.exp)
|
||||
|
||||
Reference in New Issue
Block a user