Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -911,6 +911,44 @@ class TestConv(mlx_tests.MLXTestCase):
expected = mx.array([[dw00, dw01], [dw10, dw11]])
self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
# Test with input dilation
inputs = mx.random.normal((1, 14, 14, 2))
kernel = mx.random.normal((2, 7, 7, 2))
def conv_flip(kernel):
return mx.conv_general(
inputs,
kernel,
stride=1,
padding=([6, 6], [15, 15]),
kernel_dilation=(1, 1),
input_dilation=(16, 16),
groups=1,
flip=True,
).sum()
def reverse_sequence(xs, axis=0):
indices = mx.arange(xs.shape[axis] - 1, -1, -1)
return mx.take(xs, indices, axis=axis)
def conv_manual_flip(kernel):
for ax in range(1, kernel.ndim - 1):
kernel = reverse_sequence(kernel, axis=ax)
return mx.conv_general(
inputs,
kernel,
stride=1,
padding=([6, 6], [15, 15]),
kernel_dilation=(1, 1),
input_dilation=(16, 16),
groups=1,
flip=False,
).sum()
grad = mx.grad(conv_flip)(kernel)
expected_grad = mx.grad(conv_manual_flip)(kernel)
self.assertTrue(mx.allclose(grad, expected_grad))
def test_conv_groups_grad(self):
def fn(x, w):
num_groups = x.shape[-1] // w.shape[-1]