mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups * fix * fix * fix + test
This commit is contained in:
@@ -47,6 +47,13 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(c_mx.shape, c_np.shape)
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
|
||||
|
||||
def test_conv_1d_groups_flipped(self):
|
||||
x = mx.broadcast_to(mx.arange(5).astype(mx.float32), (2, 5)).T
|
||||
w = mx.broadcast_to(mx.arange(4).astype(mx.float32), (2, 4))
|
||||
out = mx.conv_general(x[None], w[..., None], flip=True, groups=2)
|
||||
expected = mx.array([4.0, 4.0, 10.0, 10.0]).reshape(1, 2, 2)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_1D(self):
|
||||
def run_conv1D(
|
||||
@@ -897,6 +904,99 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||
self.assertTrue(mx.allclose(dw, expected))
|
||||
|
||||
def test_conv_groups_grad(self):
|
||||
def fn(x, w):
|
||||
num_groups = x.shape[-1] // w.shape[-1]
|
||||
return mx.conv1d(x, w, groups=num_groups)
|
||||
|
||||
def fn_gt(x, w):
|
||||
num_groups = x.shape[-1] // w.shape[-1]
|
||||
group_size = w.shape[-1]
|
||||
ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
|
||||
xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)
|
||||
return mx.concatenate(
|
||||
[mx.conv_general(x.squeeze(-2), w.squeeze(0)) for x, w in zip(xs, ws)],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
mx.random.seed(3)
|
||||
|
||||
w = mx.random.normal(shape=(2, 3, 1))
|
||||
x = mx.random.normal(shape=(1, 5, 2))
|
||||
cotans = (mx.ones(shape=(1, 3, 2)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
w = mx.random.normal(shape=(2, 3, 2))
|
||||
x = mx.random.normal(shape=(1, 5, 4))
|
||||
cotans = (mx.ones(shape=(1, 3, 2)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
w = mx.random.normal(shape=(6, 3, 2))
|
||||
x = mx.random.normal(shape=(1, 5, 4))
|
||||
cotans = (mx.ones(shape=(1, 3, 6)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
# Test 2D
|
||||
w = mx.random.normal(shape=(2, 3, 3, 1))
|
||||
x = mx.random.normal(shape=(1, 5, 5, 2))
|
||||
cotans = (mx.ones(shape=(1, 3, 3, 2)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
# Test with flip
|
||||
def fn(x, w):
|
||||
num_groups = x.shape[-1] // w.shape[-1]
|
||||
return mx.conv_general(x, w, groups=num_groups, flip=True)
|
||||
|
||||
def fn_gt(x, w):
|
||||
num_groups = x.shape[-1] // w.shape[-1]
|
||||
group_size = w.shape[-1]
|
||||
ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
|
||||
xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)
|
||||
return mx.concatenate(
|
||||
[
|
||||
mx.conv_general(x.squeeze(-2), w.squeeze(0), flip=True)
|
||||
for x, w in zip(xs, ws)
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
w = mx.random.normal(shape=(2, 3, 1))
|
||||
x = mx.random.normal(shape=(1, 5, 2))
|
||||
cotans = (mx.ones(shape=(1, 3, 2)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
w = mx.random.normal(shape=(2, 3, 2))
|
||||
x = mx.random.normal(shape=(1, 5, 4))
|
||||
cotans = (mx.ones(shape=(1, 3, 2)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
# Test 2D
|
||||
w = mx.random.normal(shape=(2, 3, 3, 1))
|
||||
x = mx.random.normal(shape=(1, 5, 5, 2))
|
||||
cotans = (mx.ones(shape=(1, 3, 3, 2)),)
|
||||
grads = mx.vjp(fn, (x, w), cotans)[1]
|
||||
expected = mx.vjp(fn_gt, (x, w), cotans)[1]
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user