mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix concatenate/slice_update vjp + reduce binary size (#1735)
* fix concatenate vjp + reduce binary size * also cast in slice update
This commit is contained in:
@@ -620,6 +620,18 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
x = mx.zeros((2, 4, 8))
|
||||
self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))
|
||||
|
||||
def test_concatenate_vjps(self):
|
||||
def fun(x, y):
|
||||
return mx.concatenate([x, y])
|
||||
|
||||
x = mx.array([1, 2, 3], mx.float32)
|
||||
y = mx.array([1, 2, 3], mx.float16)
|
||||
grads = mx.vjp(fun, (x, y), (mx.ones((6,)),))[1]
|
||||
self.assertTrue(mx.allclose(grads[0], mx.ones(3)))
|
||||
self.assertTrue(mx.allclose(grads[1], mx.ones(3)))
|
||||
self.assertEqual(grads[0].dtype, mx.float32)
|
||||
self.assertEqual(grads[1].dtype, mx.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user