mlx/python/tests/test_conv_transpose.py

811 lines
28 KiB
Python

# Copyright © 2023-2024 Apple Inc.
import math
import unittest
from itertools import permutations
import mlx.core as mx
import mlx_tests
import numpy as np
# PyTorch is an optional dependency used only as the reference implementation.
# ``has_torch`` lets the tests skip themselves when it is not installed.
try:
    import torch
    import torch.nn.functional as F

    has_torch = True
except ImportError:
    has_torch = False
class TestConvTranspose(mlx_tests.MLXTestCase):
    """Compare MLX transposed convolutions (1D/2D/3D) with PyTorch.

    Every test builds random NumPy data with the channel axis last
    (inputs ``(N, *spatial, C)``, weights ``(O, *kernel, C)``), feeds it to
    MLX as is, and permutes the axes before handing the same data to
    PyTorch. The PyTorch result is permuted back and compared with
    ``np.allclose``. All tests are skipped when PyTorch is not installed.
    """

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D(self):
        """Forward pass of ``mx.conv_transpose1d`` matches PyTorch."""

        def run_conv_transpose_1D(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            output_padding=0,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 1D configuration as a sub-test.

            ``output_padding`` is forwarded to both frameworks so that a
            non-default value is actually exercised.
            """
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C // groups)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                # Move the channel axis to where PyTorch expects it.
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

            # Groups tests (kept inside the dtype loop so ``dtype`` is bound
            # explicitly rather than leaking out of the loop above).
            N, C, O = (4, 32, 64)
            for iH, kH, stride, padding in (
                (1, 1, 1, 0),
                (3, 3, 1, 0),
                (31, 5, 5, 2),
            ):
                for group in (1,):
                    run_conv_transpose_1D(
                        N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
                    )

        # Strided inputs tests: feed transposed (non-contiguous) views to MLX.
        for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
            ((0, 2, 1), (0, 2, 1)),
        ):
            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_mx_t = mx.transpose(in_mx, tpose_in)
                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))
                out_pt = torch.conv_transpose1d(in_pt, wt_pt)
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D_grad(self):
        """Input and weight gradients of ``mx.conv_transpose1d`` match PyTorch."""

        def run_conv_transpose1D_grad(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 1D gradient configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                # For reference, the transposed-convolution output length is
                # oH = (iH - 1) * stride - 2 * padding + dilation * (kH - 1) + 1
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

                out_pt = F.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # use torch to compute ct
                out_pt.retain_grad()
                out_pt.sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

                def f(a, b):
                    return mx.conv_transpose1d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(f, [in_mx, wt_mx], [ct_mx])
                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose1D_grad(
                        N, C, O, iH, kH, stride, padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D(self):
        """Forward pass of ``mx.conv_transpose2d`` matches PyTorch."""

        def run_conv_transpose2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 2D configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                # Scale the input down so the accumulated outputs stay small.
                scale = 1.0 / math.sqrt(kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C // groups)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

            # Groups tests (kept inside the dtype loop so ``dtype`` is bound
            # explicitly rather than leaking out of the loop above).
            N, C, O = (4, 32, 64)
            for idim, kdim, stride, padding in (
                ((1, 1), (1, 1), (1, 1), (0, 0)),
                ((3, 3), (3, 1), (1, 1), (0, 0)),
                ((31, 31), (5, 5), (5, 5), (2, 2)),
            ):
                for group in (1,):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D_grad(self):
        """Input and weight gradients of ``mx.conv_transpose2d`` match PyTorch."""

        def run_conv_transpose2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 2D gradient configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C * O)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # use torch to compute ct
                out_pt.retain_grad()
                out_pt.sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

                def f(a, b):
                    return mx.conv_transpose2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(f, [in_mx, wt_mx], [ct_mx])
                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
                    run_conv_transpose2D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D(self):
        """Forward pass of ``mx.conv_transpose3d`` matches PyTorch."""

        def run_conv_transpose3D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 3D configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (2, 8, 16),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
                ):
                    run_conv_transpose3D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D_grad(self):
        """Input and weight gradients of ``mx.conv_transpose3d`` match PyTorch."""

        def run_conv_transpose3D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-4,
        ):
            """Run one 3D gradient configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

                # use torch to compute ct
                out_pt.retain_grad()
                out_pt.sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()
                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))

                def f(a, b):
                    return mx.conv_transpose3d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(f, [in_mx, wt_mx], [ct_mx])
                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
                    ((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
                ):
                    run_conv_transpose3D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    # NOTE: "tranpose" in the name below is a typo; it is kept so that the
    # test id stays stable for anything selecting this test by name.
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_tranpose_1d_output_padding(self):
        """``mx.conv_transpose1d`` with ``output_padding`` matches PyTorch."""

        def run_conv_transpose_1d_output_padding(
            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
        ):
            """Run one 1D output-padding configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for iH, kH, stride, padding, output_padding in (
                    (3, 2, 2, 0, 1),
                    (5, 3, 2, 1, 0),
                    (7, 4, 3, 1, 2),
                ):
                    run_conv_transpose_1d_output_padding(
                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2d_output_padding(self):
        """``mx.conv_transpose2d`` with ``output_padding`` matches PyTorch."""

        def run_conv_transpose_2d_output_padding(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 2D output-padding configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for idim, kdim, stride, padding, output_padding in (
                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
                ):
                    run_conv_transpose_2d_output_padding(
                        N,
                        C,
                        O,
                        idim,
                        kdim,
                        stride,
                        padding,
                        output_padding,
                        dtype=dtype,
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3d_output_padding(self):
        """``mx.conv_transpose3d`` with ``output_padding`` matches PyTorch."""

        def run_conv_transpose_3d_output_padding(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            """Run one 3D output-padding configuration as a sub-test."""
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for idim, kdim, stride, padding, output_padding in (
                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
                ):
                    run_conv_transpose_3d_output_padding(
                        N,
                        C,
                        O,
                        idim,
                        kdim,
                        stride,
                        padding,
                        output_padding,
                        dtype=dtype,
                    )
# Run this module's tests through the MLX test runner when executed directly.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()