Transposed Convolution (#1245)

* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Max-Heinrich Laves
2024-09-07 04:52:38 +02:00
committed by GitHub
parent ba3e913c7a
commit efeb9c0f02
15 changed files with 1337 additions and 45 deletions

View File

@@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
flip=flip,
)
def test_conv_general_flip_grad(self):
    """Check the vjp of a flipped `conv_general` against a hand-computed
    weight gradient for a 2x2 kernel, at input dilations 1 and 2."""
    for stride in (1, 2):
        weight = mx.random.normal(shape=(1, 2, 2, 1))
        inp = mx.random.normal(shape=(1, 2, 2, 1))

        def flipped_conv(w):
            # Flipped kernel + input dilation — the transposed-convolution path.
            return mx.conv_general(
                inp,
                w,
                stride=1,
                padding=(1, 1),
                kernel_dilation=1,
                input_dilation=stride,
                flip=True,
            )

        cotangent = mx.random.normal(shape=(1, 2 + stride, 2 + stride, 1))
        grad_w = mx.vjp(flipped_conv, (weight,), (cotangent,))[1][0]

        # Drop the singleton batch/channel axes for the manual computation.
        inp2d = inp.squeeze()
        cot2d = cotangent.squeeze()
        grad2d = grad_w.squeeze()

        # Each kernel tap's gradient is the sum of the input patch times the
        # cotangent entries it contributed to, subsampled by the dilation.
        expected = mx.array(
            [
                [
                    (cot2d[:-1:stride, :-1:stride] * inp2d).sum(),
                    (cot2d[:-1:stride, 1::stride] * inp2d).sum(),
                ],
                [
                    (cot2d[1::stride, :-1:stride] * inp2d).sum(),
                    (cot2d[1::stride, 1::stride] * inp2d).sum(),
                ],
            ]
        )
        self.assertTrue(mx.allclose(grad2d, expected))
# Run the test suite when this file is executed directly
# (e.g. `python test_conv.py`); no effect when imported.
if __name__ == "__main__":
    unittest.main()