Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -700,6 +700,43 @@ class TestAutograd(mlx_tests.MLXTestCase):
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
self.assertTrue(mx.allclose(expected, jout))
def test_slice_grads(self):
# Slice
def fun(a):
return a[5:-6:-1]
a = mx.ones(shape=(5,))
cotan = mx.random.uniform(shape=(5,))
_, (grad,) = mx.vjp(fun, (a,), (cotan,))
self.assertTrue(mx.allclose(grad, cotan[::-1]))
tan = mx.random.uniform(shape=(5,))
mx.eval(tan)
_, (grad,) = mx.jvp(fun, (a,), (tan,))
self.assertTrue(mx.allclose(grad, tan[::-1]))
# Slice update
def fun(a, b):
a[4:-5:-2] = b
return a
a = mx.ones(shape=(4,))
b = mx.zeros(shape=(2,))
cotan = mx.random.uniform(shape=(4,))
_, (grad_a, grad_b) = mx.vjp(fun, (a, b), (cotan,))
expected_a = mx.array(cotan)
expected_a[1::2] = 0.0
self.assertTrue(mx.allclose(grad_a, expected_a))
self.assertTrue(mx.allclose(grad_b, cotan[4:-5:-2]))
tan_a = mx.random.uniform(shape=(4,))
tan_b = mx.random.uniform(shape=(2,))
_, (grad,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
expected = tan_a
expected[4:-5:-2] = tan_b
self.assertTrue(mx.allclose(grad, expected))
if __name__ == "__main__":
unittest.main()