scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
This commit is contained in:
Awni Hannun
2025-01-31 20:48:08 -08:00
committed by GitHub
parent c6fc07f1f4
commit b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -669,6 +669,37 @@ class TestAutograd(mlx_tests.MLXTestCase):
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
self.assertTrue(mx.allclose(tangent, expected))
def test_put_along_axis_grads(self):
a = mx.zeros((5, 1))
b = mx.ones((2, 1))
def fun(a, b):
idx = mx.array([[0], [3]])
return mx.put_along_axis(a, idx, b, axis=0)
# Test VJP
cotan = mx.full((5, 1), 2.0)
_, (da, db) = mx.vjp(fun, (a, b), (cotan,))
expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
expected_db = mx.array([2.0, 2.0])[:, None]
self.assertTrue(mx.allclose(expected_da, da))
self.assertTrue(mx.allclose(expected_db, db))
# Test JVP
tan_a = mx.full((5, 1), 2.0)
tan_b = mx.full((2, 1), 3.0)
_, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None]
self.assertTrue(mx.allclose(expected, jout))
def fun(a):
idx = mx.array([[0], [3]])
return mx.put_along_axis(a, idx, b, axis=0)
_, (jout,) = mx.jvp(fun, (a,), (tan_a,))
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
self.assertTrue(mx.allclose(expected, jout))
if __name__ == "__main__":
unittest.main()