Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-18 15:28:16 +08:00)
Transposed Convolution (#1245)
* initial implementation for conv_transpose
  - ran pre-commit
  - implemented conv_transpose
  - updated conv_general docstring
  - updated code comments
  - removed commented run_conv_checks
  - updated acknowledgments
  - added missing entry to ops.rst
  - added op to nn.layers
  - resolved merge conflicts
* removed ConvolutionTranspose primitive as suggested by reviewer
* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Committed by: GitHub
Parent: ba3e913c7a
Commit: efeb9c0f02
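The reviewer-suggested simplification works because a transposed convolution is an ordinary convolution over an input-dilated signal with a flipped kernel and adjusted padding, which conv_general already supports. A minimal sketch of the equivalence this relies on (illustrative, not code from this commit; it assumes the standard padding identity p' = dilation * (kH - 1) - padding):

    import mlx.core as mx

    N, iH, C, O, kH = 1, 5, 3, 4, 3
    stride, padding, dilation = 2, 1, 1

    x = mx.random.normal(shape=(N, iH, C))
    w = mx.random.normal(shape=(O, kH, C))

    out_t = mx.conv_transpose1d(x, w, stride=stride, padding=padding, dilation=dilation)
    out_g = mx.conv_general(
        x,
        w,
        stride=1,
        padding=dilation * (kH - 1) - padding,  # adjusted padding p'
        kernel_dilation=dilation,
        input_dilation=stride,  # insert stride - 1 zeros between input samples
        flip=True,  # convolve with the spatially flipped kernel
    )
    assert mx.allclose(out_t, out_g)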
@@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
                    flip=flip,
                )

    def test_conv_general_flip_grad(self):
        for s in (1, 2):
            w = mx.random.normal(shape=(1, 2, 2, 1))
            x = mx.random.normal(shape=(1, 2, 2, 1))

            def conv_t(w):
                return mx.conv_general(
                    x,
                    w,
                    stride=1,
                    padding=(1, 1),
                    kernel_dilation=1,
                    input_dilation=s,
                    flip=True,
                )

            cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))

            dw = mx.vjp(conv_t, (w,), (cotan,))[1][0]

            x = x.squeeze()
            cotan = cotan.squeeze()
            dw = dw.squeeze()

            dw00 = (cotan[:-1:s, :-1:s] * x).sum()
            dw01 = (cotan[:-1:s, 1::s] * x).sum()
            dw10 = (cotan[1::s, :-1:s] * x).sum()
            dw11 = (cotan[1::s, 1::s] * x).sum()
            expected = mx.array([[dw00, dw01], [dw10, dw11]])
            self.assertTrue(mx.allclose(dw, expected))


if __name__ == "__main__":
    unittest.main()
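The hand-written expectations in this test follow from linearity: with flip=True, padding (1, 1), and input dilation s, each output is y[m, n] = sum over i, j of w[i, j] * xp[m + 1 - i, n + 1 - j], where xp is x input-dilated by s and zero-padded by 1; so dw[i, j] picks out a strided slice of the cotangent against x. A NumPy re-derivation sketch (illustrative, under those assumptions):

    import numpy as np

    s = 2
    x = np.random.randn(2, 2)
    cotan = np.random.randn(2 + s, 2 + s)

    # Input-dilate x by s (s - 1 zeros between samples), then zero-pad by 1.
    xd = np.zeros((s + 1, s + 1))
    xd[::s, ::s] = x
    xp = np.pad(xd, 1)

    # dw[i, j] = sum_{m, n} cotan[m, n] * xp[m + 1 - i, n + 1 - j]
    dw = np.zeros((2, 2))
    for i in range(2):
        for j in range(2):
            for m in range(2 + s):
                for n in range(2 + s):
                    dw[i, j] += cotan[m, n] * xp[m + 1 - i, n + 1 - j]

    # xp[m + 1, n + 1] is nonzero only at multiples of s, which is exactly the
    # strided slice the test uses, e.g. for dw[0, 0]:
    assert np.isclose(dw[0, 0], (cotan[:-1:s, :-1:s] * x).sum())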
python/tests/test_conv_transpose.py (new file, 601 lines)
@@ -0,0 +1,601 @@
# Copyright © 2023-2024 Apple Inc.

import math
import unittest
from itertools import permutations

import mlx.core as mx
import mlx_tests
import numpy as np

try:
    import torch
    import torch.nn.functional as F

    has_torch = True
except ImportError as e:
    has_torch = False


class TestConvTranspose(mlx_tests.MLXTestCase):
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D(self):
        def run_conv_transpose_1D(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            output_padding=0,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

        # Groups tests
        N, C, O = (4, 32, 64)
        for iH, kH, stride, padding in (
            (1, 1, 1, 0),
            (3, 3, 1, 0),
            (31, 5, 5, 2),
        ):
            for group in (1,):
                run_conv_transpose_1D(
                    N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
                )

        # Strided inputs tests
        for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
            ((0, 2, 1), (0, 2, 1)),
        ):
            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_mx_t = mx.transpose(in_mx, tpose_in)
                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))

                out_pt = torch.conv_transpose1d(in_pt, wt_pt)
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))

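# --- Aside (illustrative sketch, not part of the committed file) --------------
# The shape agreement asserted above relies on the standard transposed-conv
# output length: oH = (iH - 1) * stride - 2 * padding + dilation * (kH - 1) + 1
# (with output_padding = 0). A hedged self-check of that formula:
import torch

iH, kH, stride, padding, dilation = 31, 5, 5, 2, 1
out = torch.conv_transpose1d(
    torch.zeros(1, 1, iH),
    torch.zeros(1, 1, kH),
    stride=stride,
    padding=padding,
    dilation=dilation,
)
assert out.shape[-1] == (iH - 1) * stride - 2 * padding + dilation * (kH - 1) + 1
# ------------------------------------------------------------------------------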
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D_grad(self):
        def run_conv_transpose1D_grad(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                # oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)

                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

                out_pt = F.conv_transpose1d(
                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
                )

                # use torch to compute ct
                out_pt.retain_grad()
                (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

                def f(a, b):
                    return mx.conv_transpose1d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [
                        in_mx,
                        wt_mx,
                    ],
                    [
                        ct_mx,
                    ],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose1D_grad(
                        N, C, O, iH, kH, stride, padding, dtype=dtype
                    )

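# --- Aside (illustrative sketch, not part of the committed file) --------------
# The "use torch to compute ct" trick above: after backward() on
# (out - r).abs().sum(), out_pt.grad is sign(out - r), a random +/-1 cotangent
# that both frameworks then pull back. The same weight gradient can also be
# cross-checked against mx.grad on a scalar loss (made-up shapes, a sketch):
import mlx.core as mx

x = mx.random.normal(shape=(1, 5, 2))
w = mx.random.normal(shape=(3, 3, 2))
ct = mx.random.normal(shape=(1, 7, 3))  # cotangent with the output's shape

def f(x, w):
    return mx.conv_transpose1d(x, w)

gw = mx.grad(lambda w: (f(x, w) * ct).sum())(w)
_, (gx, gw_vjp) = mx.vjp(f, [x, w], [ct])
assert mx.allclose(gw, gw_vjp)
# ------------------------------------------------------------------------------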
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D(self):
        def run_conv_transpose2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

        # Groups tests
        N, C, O = (4, 32, 64)
        for idim, kdim, stride, padding in (
            ((1, 1), (1, 1), (1, 1), (0, 0)),
            ((3, 3), (3, 1), (1, 1), (0, 0)),
            ((31, 31), (5, 5), (5, 5), (2, 2)),
        ):
            for group in (1,):
                run_conv_transpose2D(
                    N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                )

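# --- Aside (illustrative sketch, not part of the committed file) --------------
# The transposes above bridge layouts: MLX activations are channels-last while
# torch is channels-first, and torch's conv_transpose weights put in-channels
# first. The 2D mapping used throughout (NHWC <-> NCHW for activations,
# (O, kH, kW, C) -> (C, O, kH, kW) for weights):
import numpy as np

a_nhwc = np.zeros((4, 8, 8, 3), dtype=np.float32)  # MLX activation layout
a_nchw = a_nhwc.transpose(0, 3, 1, 2)              # torch activation layout
assert a_nchw.transpose(0, 2, 3, 1).shape == a_nhwc.shape  # round trip

w_mlx = np.zeros((16, 3, 3, 8), dtype=np.float32)  # MLX weight layout (O, kH, kW, C)
w_pt = w_mlx.transpose(3, 0, 1, 2)                 # torch layout (C, O, kH, kW)
assert w_pt.shape == (8, 16, 3, 3)
# ------------------------------------------------------------------------------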
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D_grad(self):
        def run_conv_transpose2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose2d(
                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
                )

                # use torch to compute ct
                out_pt.retain_grad()
                (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

                def f(a, b):
                    return mx.conv_transpose2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
                    run_conv_transpose2D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D(self):
        def run_conv_transpose3D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (2, 8, 16),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
                ):
                    run_conv_transpose3D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

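# --- Aside (illustrative sketch, not part of the committed file) --------------
# Why the 1 / sqrt(fan) scales used in these tests: each output element sums
# kD * kH * kW * C products, so scaling the inputs keeps outputs O(1) and a
# fixed atol meaningful across shapes. A rough check of that variance argument:
import numpy as np

kD = kH = kW = 3
C, O = 8, 16
scale = 1.0 / np.sqrt(kD * kH * kW * C * O)
# std of a sum of n independent products of std-`scale` and std-1 terms is
# roughly sqrt(n) * scale
approx_out_std = np.sqrt(kD * kH * kW * C) * scale
assert np.isclose(approx_out_std, 1.0 / np.sqrt(O))
# ------------------------------------------------------------------------------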
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D_grad(self):
        def run_conv_transpose3D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-4,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # use torch to compute ct
                out_pt.retain_grad()
                (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))

                def f(a, b):
                    return mx.conv_transpose3d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
                ):
                    run_conv_transpose3D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )


if __name__ == "__main__":
    unittest.main()