Conv grad with groups + bugfix (#1449)

* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
This commit is contained in:
Awni Hannun
2024-10-06 07:08:53 -07:00
committed by GitHub
parent fef3c4ec1d
commit e4534dac17
6 changed files with 197 additions and 176 deletions

View File

@@ -47,6 +47,13 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertEqual(c_mx.shape, c_np.shape)
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
def test_conv_1d_groups_flipped(self):
    """Flipped grouped 1D convolution reproduces hand-computed values."""
    # Input of shape (5, 2): both columns hold the ramp 0..4.
    ramp = mx.arange(5).astype(mx.float32)
    x = mx.broadcast_to(ramp, (2, 5)).T

    # Weights of shape (2, 4): both rows hold the ramp 0..3.
    taps = mx.arange(4).astype(mx.float32)
    w = mx.broadcast_to(taps, (2, 4))

    # Add a leading axis to the input and a trailing axis to the weights
    # before running the convolution with one filter per group.
    out = mx.conv_general(x[None], w[..., None], flip=True, groups=2)

    # With the kernel flipped to [3, 2, 1, 0] the two valid positions give
    # 0*3 + 1*2 + 2*1 + 3*0 = 4 and 1*3 + 2*2 + 3*1 + 4*0 = 10.
    expected = mx.array([4.0, 4.0, 10.0, 10.0]).reshape(1, 2, 2)
    self.assertTrue(mx.allclose(out, expected))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_1D(self):
def run_conv1D(
@@ -897,6 +904,99 @@ class TestConv(mlx_tests.MLXTestCase):
expected = mx.array([[dw00, dw01], [dw10, dw11]])
self.assertTrue(mx.allclose(dw, expected))
def test_conv_groups_grad(self):
    """Compare grouped-convolution VJPs against a per-group reference.

    The reference splits the filters and the input's last axis into
    ``num_groups`` pieces, convolves each piece on its own (without passing
    ``groups``) and concatenates the results on the last axis. Gradients of
    the grouped call and of the reference are then required to agree for
    both the input and the weights, with and without ``flip=True``.
    """

    def make_reference(flip):
        # Build the split-convolve-concatenate reference for one flip mode.
        def fn_gt(x, w):
            num_groups = x.shape[-1] // w.shape[-1]
            # One leading-axis slice of the filters per group.
            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
            # One slice of the input's last axis per group.
            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(
                num_groups, axis=-2
            )
            return mx.concatenate(
                [
                    mx.conv_general(x_g.squeeze(-2), w_g.squeeze(0), flip=flip)
                    for x_g, w_g in zip(xs, ws)
                ],
                axis=-1,
            )

        return fn_gt

    def check_grads(fn, fn_gt, w_shape, x_shape, cotan_shape):
        # Run one case: sample operands, take both VJPs, compare gradients.
        # w is drawn before x so the seeded random stream is consumed in a
        # fixed order from case to case.
        w = mx.random.normal(shape=w_shape)
        x = mx.random.normal(shape=x_shape)
        cotans = (mx.ones(shape=cotan_shape),)
        grads = mx.vjp(fn, (x, w), cotans)[1]
        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
        case = f"w={w_shape}, x={x_shape}"
        self.assertTrue(
            mx.allclose(expected[0], grads[0]),
            f"input gradient mismatch for {case}",
        )
        self.assertTrue(
            mx.allclose(expected[1], grads[1]),
            f"weight gradient mismatch for {case}",
        )

    def grouped_conv(x, w):
        num_groups = x.shape[-1] // w.shape[-1]
        return mx.conv1d(x, w, groups=num_groups)

    def grouped_conv_flipped(x, w):
        num_groups = x.shape[-1] // w.shape[-1]
        return mx.conv_general(x, w, groups=num_groups, flip=True)

    # Each case is (weight shape, input shape, cotangent shape).
    # NOTE(review): the 2D case below is also run through mx.conv1d; confirm
    # that is intended rather than mx.conv2d / mx.conv_general.
    plain_cases = [
        ((2, 3, 1), (1, 5, 2), (1, 3, 2)),
        ((2, 3, 2), (1, 5, 4), (1, 3, 2)),
        ((6, 3, 2), (1, 5, 4), (1, 3, 6)),
        # Test 2D
        ((2, 3, 3, 1), (1, 5, 5, 2), (1, 3, 3, 2)),
    ]
    flipped_cases = [
        ((2, 3, 1), (1, 5, 2), (1, 3, 2)),
        ((2, 3, 2), (1, 5, 4), (1, 3, 2)),
        # Test 2D
        ((2, 3, 3, 1), (1, 5, 5, 2), (1, 3, 3, 2)),
    ]

    mx.random.seed(3)

    reference = make_reference(flip=False)
    for w_shape, x_shape, cotan_shape in plain_cases:
        check_grads(grouped_conv, reference, w_shape, x_shape, cotan_shape)

    # Test with flip
    reference_flipped = make_reference(flip=True)
    for w_shape, x_shape, cotan_shape in flipped_cases:
        check_grads(
            grouped_conv_flipped,
            reference_flipped,
            w_shape,
            x_shape,
            cotan_shape,
        )
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()