mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups
* fix
* fix
* fix + test
This commit is contained in:
		@@ -47,6 +47,13 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
			
		||||
                    self.assertEqual(c_mx.shape, c_np.shape)
 | 
			
		||||
                    self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
 | 
			
		||||
 | 
			
		||||
    def test_conv_1d_groups_flipped(self):
        """Grouped 1D convolution with ``flip=True`` on a hand-computed case.

        Both channels carry the ramp 0..4 and both groups use the kernel
        0..3. With the kernel flipped, the two valid output positions are
        0*3 + 1*2 + 2*1 + 3*0 = 4 and 1*3 + 2*2 + 3*1 + 4*0 = 10 for
        every channel.
        """
        ramp = mx.arange(5).astype(mx.float32)
        taps = mx.arange(4).astype(mx.float32)

        # Same signal in both channels: shape (5, 2) after the transpose.
        x = mx.broadcast_to(ramp, (2, 5)).T
        # Same kernel for both groups: shape (2, 4).
        w = mx.broadcast_to(taps, (2, 4))

        # Add a leading axis to the input and a trailing axis to the weight
        # before running the grouped, flipped convolution.
        batched_input = x[None]
        grouped_weight = w[..., None]
        out = mx.conv_general(batched_input, grouped_weight, flip=True, groups=2)

        expected = mx.array([4.0, 4.0, 10.0, 10.0]).reshape(1, 2, 2)
        self.assertTrue(mx.allclose(out, expected))
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not has_torch, "requires Torch")
 | 
			
		||||
    def test_torch_conv_1D(self):
 | 
			
		||||
        def run_conv1D(
 | 
			
		||||
@@ -897,6 +904,99 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
			
		||||
            expected = mx.array([[dw00, dw01], [dw10, dw11]])
 | 
			
		||||
            self.assertTrue(mx.allclose(dw, expected))
 | 
			
		||||
 | 
			
		||||
    def test_conv_groups_grad(self):
        """Check VJPs of grouped convolutions against a per-group reference.

        For every case the gradients with respect to both the input and the
        weight, obtained through ``mx.vjp`` with an all-ones cotangent, must
        match those of a reference that splits the input and weight into
        groups, convolves each group separately (no ``groups`` argument) and
        concatenates the results along the last axis.

        The number of groups is always inferred as
        ``x.shape[-1] // w.shape[-1]``. Both the plain and the ``flip=True``
        variants are exercised, each on 1D and 2D inputs.
        """

        def grouped_plain(x, w):
            # NOTE(review): as in the original test, mx.conv1d is also fed the
            # 2D case below; presumably it forwards to the general N-D
            # convolution -- confirm before relying on it elsewhere.
            num_groups = x.shape[-1] // w.shape[-1]
            return mx.conv1d(x, w, groups=num_groups)

        def grouped_flipped(x, w):
            num_groups = x.shape[-1] // w.shape[-1]
            return mx.conv_general(x, w, groups=num_groups, flip=True)

        def reference(x, w, flip):
            # Ground truth: one ungrouped convolution per group.
            num_groups = x.shape[-1] // w.shape[-1]
            # Weight: leading axis -> (num_groups, per-group outputs, ...).
            w_groups = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
            # Input: last axis -> (num_groups, per-group channels).
            x_groups = x.reshape(*x.shape[:-1], num_groups, -1).split(
                num_groups, axis=-2
            )
            return mx.concatenate(
                [
                    mx.conv_general(xg.squeeze(-2), wg.squeeze(0), flip=flip)
                    for xg, wg in zip(x_groups, w_groups)
                ],
                axis=-1,
            )

        def reference_plain(x, w):
            return reference(x, w, flip=False)

        def reference_flipped(x, w):
            return reference(x, w, flip=True)

        def check(fn, fn_gt, w_shape, x_shape, out_shape):
            # Draw w before x so the random stream is consumed in the same
            # order as the original hand-unrolled test.
            w = mx.random.normal(shape=w_shape)
            x = mx.random.normal(shape=x_shape)
            cotans = (mx.ones(shape=out_shape),)
            grads = mx.vjp(fn, (x, w), cotans)[1]
            expected = mx.vjp(fn_gt, (x, w), cotans)[1]
            case = f"w={w_shape}, x={x_shape}, fn={fn.__name__}"
            self.assertTrue(
                mx.allclose(expected[0], grads[0]),
                msg=f"input gradient mismatch ({case})",
            )
            self.assertTrue(
                mx.allclose(expected[1], grads[1]),
                msg=f"weight gradient mismatch ({case})",
            )

        mx.random.seed(3)

        # (weight shape, input shape, output/cotangent shape)
        plain_cases = [
            ((2, 3, 1), (1, 5, 2), (1, 3, 2)),
            ((2, 3, 2), (1, 5, 4), (1, 3, 2)),
            ((6, 3, 2), (1, 5, 4), (1, 3, 6)),
            # Test 2D
            ((2, 3, 3, 1), (1, 5, 5, 2), (1, 3, 3, 2)),
        ]
        for w_shape, x_shape, out_shape in plain_cases:
            check(grouped_plain, reference_plain, w_shape, x_shape, out_shape)

        # Test with flip
        flipped_cases = [
            ((2, 3, 1), (1, 5, 2), (1, 3, 2)),
            ((2, 3, 2), (1, 5, 4), (1, 3, 2)),
            # Test 2D
            ((2, 3, 3, 1), (1, 5, 5, 2), (1, 3, 3, 2)),
        ]
        for w_shape, x_shape, out_shape in flipped_cases:
            check(grouped_flipped, reference_flipped, w_shape, x_shape, out_shape)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user