Fix concatenate/slice_update vjp + reduce binary size (#1735)

* fix concatenate vjp + reduce binary size

* also cast in slice update
This commit is contained in:
Awni Hannun
2025-01-02 16:36:33 -08:00
committed by GitHub
parent ae69cb15e9
commit 6fa0501387
4 changed files with 38 additions and 26 deletions

View File

@@ -620,6 +620,18 @@ class TestAutograd(mlx_tests.MLXTestCase):
x = mx.zeros((2, 4, 8))
self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))
def test_concatenate_vjps(self):
def fun(x, y):
return mx.concatenate([x, y])
x = mx.array([1, 2, 3], mx.float32)
y = mx.array([1, 2, 3], mx.float16)
grads = mx.vjp(fun, (x, y), (mx.ones((6,)),))[1]
self.assertTrue(mx.allclose(grads[0], mx.ones(3)))
self.assertTrue(mx.allclose(grads[1], mx.ones(3)))
self.assertEqual(grads[0].dtype, mx.float32)
self.assertEqual(grads[1].dtype, mx.float16)
if __name__ == "__main__":
unittest.main()