mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
@@ -700,6 +700,43 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
def test_slice_grads(self):
|
||||
# Slice
|
||||
def fun(a):
|
||||
return a[5:-6:-1]
|
||||
|
||||
a = mx.ones(shape=(5,))
|
||||
cotan = mx.random.uniform(shape=(5,))
|
||||
_, (grad,) = mx.vjp(fun, (a,), (cotan,))
|
||||
self.assertTrue(mx.allclose(grad, cotan[::-1]))
|
||||
|
||||
tan = mx.random.uniform(shape=(5,))
|
||||
mx.eval(tan)
|
||||
_, (grad,) = mx.jvp(fun, (a,), (tan,))
|
||||
self.assertTrue(mx.allclose(grad, tan[::-1]))
|
||||
|
||||
# Slice update
|
||||
def fun(a, b):
|
||||
a[4:-5:-2] = b
|
||||
return a
|
||||
|
||||
a = mx.ones(shape=(4,))
|
||||
b = mx.zeros(shape=(2,))
|
||||
|
||||
cotan = mx.random.uniform(shape=(4,))
|
||||
_, (grad_a, grad_b) = mx.vjp(fun, (a, b), (cotan,))
|
||||
expected_a = mx.array(cotan)
|
||||
expected_a[1::2] = 0.0
|
||||
self.assertTrue(mx.allclose(grad_a, expected_a))
|
||||
self.assertTrue(mx.allclose(grad_b, cotan[4:-5:-2]))
|
||||
|
||||
tan_a = mx.random.uniform(shape=(4,))
|
||||
tan_b = mx.random.uniform(shape=(2,))
|
||||
_, (grad,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
|
||||
expected = tan_a
|
||||
expected[4:-5:-2] = tan_b
|
||||
self.assertTrue(mx.allclose(grad, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user