mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Transposed Convolution (#1245)
* initial implementation for conv_transpose ran pre-commit implemented conv_transpose updated conv_general docstring updated conv_general docstring updated code comments removed commented run_conv_checks updated acknowledgments added missing entry to ops.rst added op to nn.layers resolved merge conflicts * removed ConvolutionTranspose primitive as suggested by reviewer removed ConvolutionTranspose primitive as suggested by reviewer * remove transpose flag, add another test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
ba3e913c7a
commit
efeb9c0f02
@@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
def test_conv_general_flip_grad(self):
|
||||
for s in (1, 2):
|
||||
w = mx.random.normal(shape=(1, 2, 2, 1))
|
||||
x = mx.random.normal(shape=(1, 2, 2, 1))
|
||||
|
||||
def conv_t(w):
|
||||
return mx.conv_general(
|
||||
x,
|
||||
w,
|
||||
stride=1,
|
||||
padding=(1, 1),
|
||||
kernel_dilation=1,
|
||||
input_dilation=s,
|
||||
flip=True,
|
||||
)
|
||||
|
||||
cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))
|
||||
|
||||
dw = mx.vjp(conv_t, (w,), (cotan,))[1][0]
|
||||
|
||||
x = x.squeeze()
|
||||
cotan = cotan.squeeze()
|
||||
dw = dw.squeeze()
|
||||
|
||||
dw00 = (cotan[:-1:s, :-1:s] * x).sum()
|
||||
dw01 = (cotan[:-1:s, 1::s] * x).sum()
|
||||
dw10 = (cotan[1::s, :-1:s] * x).sum()
|
||||
dw11 = (cotan[1::s, 1::s] * x).sum()
|
||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||
self.assertTrue(mx.allclose(dw, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user