mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
@@ -700,6 +700,43 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
def test_slice_grads(self):
|
||||
# Slice
|
||||
def fun(a):
|
||||
return a[5:-6:-1]
|
||||
|
||||
a = mx.ones(shape=(5,))
|
||||
cotan = mx.random.uniform(shape=(5,))
|
||||
_, (grad,) = mx.vjp(fun, (a,), (cotan,))
|
||||
self.assertTrue(mx.allclose(grad, cotan[::-1]))
|
||||
|
||||
tan = mx.random.uniform(shape=(5,))
|
||||
mx.eval(tan)
|
||||
_, (grad,) = mx.jvp(fun, (a,), (tan,))
|
||||
self.assertTrue(mx.allclose(grad, tan[::-1]))
|
||||
|
||||
# Slice update
|
||||
def fun(a, b):
|
||||
a[4:-5:-2] = b
|
||||
return a
|
||||
|
||||
a = mx.ones(shape=(4,))
|
||||
b = mx.zeros(shape=(2,))
|
||||
|
||||
cotan = mx.random.uniform(shape=(4,))
|
||||
_, (grad_a, grad_b) = mx.vjp(fun, (a, b), (cotan,))
|
||||
expected_a = mx.array(cotan)
|
||||
expected_a[1::2] = 0.0
|
||||
self.assertTrue(mx.allclose(grad_a, expected_a))
|
||||
self.assertTrue(mx.allclose(grad_b, cotan[4:-5:-2]))
|
||||
|
||||
tan_a = mx.random.uniform(shape=(4,))
|
||||
tan_b = mx.random.uniform(shape=(2,))
|
||||
_, (grad,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
|
||||
expected = tan_a
|
||||
expected[4:-5:-2] = tan_b
|
||||
self.assertTrue(mx.allclose(grad, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -911,6 +911,44 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||
self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
|
||||
|
||||
# Test with input dilation
|
||||
inputs = mx.random.normal((1, 14, 14, 2))
|
||||
kernel = mx.random.normal((2, 7, 7, 2))
|
||||
|
||||
def conv_flip(kernel):
|
||||
return mx.conv_general(
|
||||
inputs,
|
||||
kernel,
|
||||
stride=1,
|
||||
padding=([6, 6], [15, 15]),
|
||||
kernel_dilation=(1, 1),
|
||||
input_dilation=(16, 16),
|
||||
groups=1,
|
||||
flip=True,
|
||||
).sum()
|
||||
|
||||
def reverse_sequence(xs, axis=0):
|
||||
indices = mx.arange(xs.shape[axis] - 1, -1, -1)
|
||||
return mx.take(xs, indices, axis=axis)
|
||||
|
||||
def conv_manual_flip(kernel):
|
||||
for ax in range(1, kernel.ndim - 1):
|
||||
kernel = reverse_sequence(kernel, axis=ax)
|
||||
return mx.conv_general(
|
||||
inputs,
|
||||
kernel,
|
||||
stride=1,
|
||||
padding=([6, 6], [15, 15]),
|
||||
kernel_dilation=(1, 1),
|
||||
input_dilation=(16, 16),
|
||||
groups=1,
|
||||
flip=False,
|
||||
).sum()
|
||||
|
||||
grad = mx.grad(conv_flip)(kernel)
|
||||
expected_grad = mx.grad(conv_manual_flip)(kernel)
|
||||
self.assertTrue(mx.allclose(grad, expected_grad))
|
||||
|
||||
def test_conv_groups_grad(self):
|
||||
def fn(x, w):
|
||||
num_groups = x.shape[-1] // w.shape[-1]
|
||||
|
@@ -587,10 +587,10 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
for idim, kdim, stride, padding, dilation in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
|
||||
((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
|
||||
((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
|
||||
((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
|
||||
((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
|
||||
):
|
||||
run_conv_transpose3D_grad(
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
|
@@ -2816,6 +2816,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(a.shape, (3, 4, 2))
|
||||
self.assertEqual(b.shape, (3, 4, 2))
|
||||
|
||||
def test_slice_update_reversed(self):
|
||||
a = mx.array([1, 2, 3, 4])
|
||||
b = a[::-1]
|
||||
b[::2] = 0
|
||||
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user