Transposed Convolution (#1245)

* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Max-Heinrich Laves
2024-09-07 04:52:38 +02:00
committed by GitHub
parent ba3e913c7a
commit efeb9c0f02
15 changed files with 1337 additions and 45 deletions

View File

@@ -55,6 +55,11 @@ from mlx.nn.layers.activations import (
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
from mlx.nn.layers.convolution_transpose import (
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear

View File

@@ -21,9 +21,9 @@ class Conv1d(Module):
out_channels (int): The number of output channels
kernel_size (int): The size of the convolution filters
stride (int, optional): The stride when applying the filter.
Default: 1.
Default: ``1``.
padding (int, optional): How many positions to 0-pad the input with.
Default: 0.
Default: ``0``.
dilation (int, optional): The dilation of the convolution.
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
@@ -84,9 +84,9 @@ class Conv2d(Module):
out_channels (int): The number of output channels.
kernel_size (int or tuple): The size of the convolution filters.
stride (int or tuple, optional): The size of the stride when
applying the filter. Default: 1.
applying the filter. Default: ``1``.
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: 0.
the input with. Default: ``0``.
dilation (int or tuple, optional): The dilation of the convolution.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``

View File

@@ -0,0 +1,206 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Union
import mlx.core as mx
from mlx.nn.layers.base import Module
class ConvTranspose1d(Module):
    """Applies a 1-dimensional transposed convolution over the multi-channel input sequence.

    The channels are expected to be last i.e. the input shape should be ``NLC`` where:

    * ``N`` is the batch dimension
    * ``L`` is the sequence length
    * ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels
        out_channels (int): The number of output channels
        kernel_size (int): The size of the convolution filters
        stride (int, optional): The stride when applying the filter.
            Default: ``1``.
        padding (int, optional): How many positions to 0-pad the input with.
            Default: ``0``.
        dilation (int, optional): The dilation of the convolution.
            Default: ``1``.
        bias (bool, optional): If ``True`` add a learnable bias to the output.
            Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        bias: bool = True,
    ):
        super().__init__()

        # Weights are drawn uniformly from [-bound, bound] with
        # bound = sqrt(1 / (in_channels * kernel_size)).
        bound = math.sqrt(1 / (in_channels * kernel_size))
        weight_shape = (out_channels, kernel_size, in_channels)
        self.weight = mx.random.uniform(low=-bound, high=bound, shape=weight_shape)

        # The bias attribute only exists when requested; the other methods
        # test for it with ``"bias" in self``.
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.dilation = dilation
        self.stride = stride

    def _extra_repr(self):
        """Summarise the layer configuration, reading sizes back from the weight."""
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x):
        """Run ``mx.conv_transpose1d`` on ``x`` and add the bias if present."""
        out = mx.conv_transpose1d(
            x, self.weight, self.stride, self.padding, self.dilation
        )
        return out + self.bias if "bias" in self else out
class ConvTranspose2d(Module):
    """Applies a 2-dimensional transposed convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:

    * ``N`` is the batch dimension
    * ``H`` is the input image height
    * ``W`` is the input image width
    * ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: ``1``.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: ``0``.
        dilation (int or tuple, optional): The dilation of the convolution.
            Default: ``1``.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        dilation: Union[int, tuple] = 1,
        bias: bool = True,
    ):
        super().__init__()

        # An int is expanded to the same value for both spatial dimensions.
        # ``dilation`` is stored exactly as given and forwarded to the op.
        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        # The bias attribute only exists when requested; the other methods
        # test for it with ``"bias" in self``.
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride
        self.dilation = dilation

    def _extra_repr(self):
        """Summarise the layer configuration, reading sizes back from the weight."""
        # The weight is laid out as (out_channels, kH, kW, in_channels), so the
        # two kernel dimensions are shape[1:3]; shape[1:2] would show only kH.
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x):
        """Run ``mx.conv_transpose2d`` on ``x`` and add the bias if present."""
        y = mx.conv_transpose2d(
            x, self.weight, self.stride, self.padding, self.dilation
        )
        if "bias" in self:
            y = y + self.bias
        return y
class ConvTranspose3d(Module):
    """Applies a 3-dimensional transposed convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:

    * ``N`` is the batch dimension
    * ``D`` is the input image depth
    * ``H`` is the input image height
    * ``W`` is the input image width
    * ``C`` is the number of input channels

    This layer does not take a ``dilation`` argument: only ``stride`` and
    ``padding`` are forwarded to ``mx.conv_transpose3d``.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: ``1``.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: ``0``.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        bias: bool = True,
    ):
        super().__init__()

        # An int is expanded to the same value for all three spatial dimensions.
        kernel_size, stride, padding = map(
            lambda x: (x, x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(
            1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
        )
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        # The bias attribute only exists when requested; the other methods
        # test for it with ``"bias" in self``.
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        """Summarise the layer configuration, reading sizes back from the weight."""
        # The weight is laid out as (out_channels, kD, kH, kW, in_channels), so
        # the three kernel dimensions are shape[1:4]; shape[1:3] would drop kW.
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        """Run ``mx.conv_transpose3d`` on ``x`` and add the bias if present."""
        y = mx.conv_transpose3d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y

View File

@@ -3238,12 +3238,12 @@ void init_ops(nb::module_& m) {
1D convolution over an input with several channels
Args:
input (array): input array of shape (``N``, ``H``, ``C_in``)
weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
stride (int, optional): kernel stride. Default: ``1``.
padding (int, optional): input padding. Default: ``0``.
dilation (int, optional): kernel dilation. Default: ``1``.
groups (int, optional): input feature groups. Default: ``1``.
input (array): Input array of shape ``(N, H, C_in)``.
weight (array): Weight array of shape ``(C_out, H, C_in)``.
stride (int, optional): Kernel stride. Default: ``1``.
padding (int, optional): Input padding. Default: ``0``.
dilation (int, optional): Kernel dilation. Default: ``1``.
groups (int, optional): Input feature groups. Default: ``1``.
Returns:
array: The convolved array.
@@ -3296,8 +3296,8 @@ void init_ops(nb::module_& m) {
2D convolution over an input with several channels
Args:
input (array): input array of shape ``(N, H, W, C_in)``
weight (array): weight array of shape ``(C_out, H, W, C_in)``
input (array): Input array of shape ``(N, H, W, C_in)``.
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
@@ -3368,8 +3368,173 @@ void init_ops(nb::module_& m) {
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): input array of shape ``(N, D, H, W, C_in)``
weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
input (array): Input array of shape ``(N, D, H, W, C_in)``.
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
symmetric input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
// Python binding for the 1D transposed convolution; the C++ function is
// registered directly because every option is a plain int.
m.def(
"conv_transpose1d",
&conv_transpose1d,
// The two unnamed arguments are `input` and `weight`; the signature string
// below marks them positional-only with `/`.
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"groups"_a = 1,
// Everything after this marker is keyword-only (the `*` in the signature).
nb::kw_only(),
"stream"_a = nb::none(),
// Explicit signature string supplied for the Python-facing function.
nb::sig(
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
1D transposed convolution over an input with several channels
Args:
input (array): Input array of shape ``(N, H, C_in)``.
weight (array): Weight array of shape ``(C_out, H, C_in)``.
stride (int, optional): Kernel stride. Default: ``1``.
padding (int, optional): Input padding. Default: ``0``.
dilation (int, optional): Kernel dilation. Default: ``1``.
groups (int, optional): Input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
// Python binding for the 2D transposed convolution. `stride`, `padding` and
// `dilation` each accept either one int or a pair of ints.
m.def(
    "conv_transpose2d",
    [](const array& input,
       const array& weight,
       const std::variant<int, std::pair<int, int>>& stride,
       const std::variant<int, std::pair<int, int>>& padding,
       const std::variant<int, std::pair<int, int>>& dilation,
       int groups,
       StreamOrDevice s) {
      // A lone int applies to both spatial dimensions; a pair is used as is.
      auto as_pair = [](const std::variant<int, std::pair<int, int>>& value) {
        if (const int* scalar = std::get_if<int>(&value)) {
          return std::pair<int, int>{*scalar, *scalar};
        }
        return std::get<std::pair<int, int>>(value);
      };
      return conv_transpose2d(
          input,
          weight,
          as_pair(stride),
          as_pair(padding),
          as_pair(dilation),
          groups,
          s);
    },
    nb::arg(),
    nb::arg(),
    "stride"_a = 1,
    "padding"_a = 0,
    "dilation"_a = 1,
    "groups"_a = 1,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
2D transposed convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, H, W, C_in)``.
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
symmetric input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"conv_transpose3d",
[](const array& input,
const array& weight,
const std::variant<int, std::tuple<int, int, int>>& stride,
const std::variant<int, std::tuple<int, int, int>>& padding,
const std::variant<int, std::tuple<int, int, int>>& dilation,
int groups,
StreamOrDevice s) {
std::tuple<int, int, int> stride_tuple{1, 1, 1};
std::tuple<int, int, int> padding_tuple{0, 0, 0};
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
if (auto pv = std::get_if<int>(&stride); pv) {
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
}
if (auto pv = std::get_if<int>(&dilation); pv) {
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
}
return conv_transpose3d(
input,
weight,
stride_tuple,
padding_tuple,
dilation_tuple,
groups,
s);
},
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
3D transposed convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, D, H, W, C_in)``.
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
@@ -3465,8 +3630,8 @@ void init_ops(nb::module_& m) {
General convolution over an input with several channels
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
input (array): Input array of shape ``(N, ..., C_in)``.
weight (array): Weight array of shape ``(C_out, ..., C_in)``.
stride (int or list(int), optional): :obj:`list` with kernel strides.
All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.

View File

@@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
flip=flip,
)
def test_conv_general_flip_grad(self):
    """Check the weight gradient of a flipped ``conv_general`` by hand.

    For input dilations 1 and 2 the VJP with respect to a 2x2 kernel is
    compared with sums of the cotangent, sampled with step ``s``, times
    the input.
    """
    for s in (1, 2):
        w = mx.random.normal(shape=(1, 2, 2, 1))
        x = mx.random.normal(shape=(1, 2, 2, 1))

        def conv_t(kernel):
            return mx.conv_general(
                x,
                kernel,
                stride=1,
                padding=(1, 1),
                kernel_dilation=1,
                input_dilation=s,
                flip=True,
            )

        cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))
        _, (dw,) = mx.vjp(conv_t, (w,), (cotan,))

        # Drop the singleton batch/channel axes so everything is 2-D.
        x_2d = x.squeeze()
        cotan_2d = cotan.squeeze()
        dw_2d = dw.squeeze()

        # Kernel tap 0 pairs with cotangent positions [:-1:s], tap 1 with [1::s].
        taps = (slice(None, -1, s), slice(1, None, s))
        expected = mx.array(
            [[(cotan_2d[row, col] * x_2d).sum() for col in taps] for row in taps]
        )
        self.assertTrue(mx.allclose(dw_2d, expected))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,601 @@
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
from itertools import permutations
import mlx.core as mx
import mlx_tests
import numpy as np
try:
import torch
import torch.nn.functional as F
has_torch = True
except ImportError as e:
has_torch = False
class TestConvTranspose(mlx_tests.MLXTestCase):
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_1D(self):
    """Compare ``mx.conv_transpose1d`` outputs with ``torch.conv_transpose1d``.

    Arrays are created as ``(N, iH, C)`` inputs and ``(O, kH, C)`` weights and
    transposed to ``(N, C, iH)`` / ``(C, O, kH)`` before the Torch call; the
    Torch output is transposed back before comparing.
    """

    def run_conv_transpose_1D(
        N,
        C,
        O,
        iH,
        kH,
        stride,
        padding,
        dilation=1,
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            iH=iH,
            kH=kH,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            # Re-seed for every case so each sub-test is reproducible alone.
            np.random.seed(0)
            in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
            # Integer division: the channel count must stay an exact int.
            wt_np = np.random.normal(0, 1.0 / C, (O, kH, C // groups)).astype(
                np_dtype
            )

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
            wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

            out_mx = mx.conv_transpose1d(
                in_mx,
                wt_mx,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            out_pt = torch.conv_transpose1d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            out_pt = torch.transpose(out_pt, 2, 1)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

    geometries = (
        (1, 1, 1, 0),
        (3, 3, 1, 0),
        (31, 5, 5, 2),
    )

    for dtype in ("float32",):
        for N, C, O in (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (4, 32, 64),
        ):
            for iH, kH, stride, padding in geometries:
                run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

        # Groups tests (only groups=1 is exercised here)
        N, C, O = (4, 32, 64)
        for iH, kH, stride, padding in geometries:
            for group in (1,):
                run_conv_transpose_1D(
                    N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
                )

    # Strided inputs tests: feed transposed (non-contiguous) views to MLX.
    for tpose_in, tpose_wt in (
        ((0, 2, 1), (0, 1, 2)),
        ((0, 2, 1), (0, 2, 1)),
    ):
        with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
            in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
            wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_mx_t = mx.transpose(in_mx, tpose_in)
            wt_mx_t = mx.transpose(wt_mx, tpose_wt)
            out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

            in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
            wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))
            out_pt = torch.conv_transpose1d(in_pt, wt_pt)
            out_pt = torch.transpose(out_pt, 2, 1)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_1D_grad(self):
    """Compare ``mx.conv_transpose1d`` input/weight VJPs with Torch autograd.

    Torch produces both the cotangent (the gradient retained on its output)
    and the reference gradients; the same cotangent is then fed to ``mx.vjp``.
    """

    def run_conv_transpose1D_grad(
        N,
        C,
        O,
        iH,
        kH,
        stride,
        padding,
        dilation=1,
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            iH=iH,
            kH=kH,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
            wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
            wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

            # Pass ``groups`` to the reference too, so both frameworks are
            # always called with the same configuration.
            out_pt = F.conv_transpose1d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )

            # use torch to compute ct
            out_pt.retain_grad()
            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

            pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
            pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()

            ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

            def f(a, b):
                return mx.conv_transpose1d(
                    a,
                    b,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

            _, outs_mx = mx.vjp(f, [in_mx, wt_mx], [ct_mx])
            mx_grad_in, mx_grad_wt = outs_mx

            self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
            self.assertEqual(in_mx.shape, mx_grad_in.shape)
            self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

            self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
            self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

    for dtype in ("float32",):
        for N, C, O in (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (4, 32, 64),
        ):
            for iH, kH, stride, padding in (
                (1, 1, 1, 0),
                (3, 3, 1, 0),
                (31, 5, 5, 2),
            ):
                run_conv_transpose1D_grad(
                    N, C, O, iH, kH, stride, padding, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_2D(self):
    """Compare ``mx.conv_transpose2d`` outputs with ``torch.conv_transpose2d``.

    Inputs are built as ``(N, iH, iW, C)`` and weights as ``(O, kH, kW, C)``;
    both are permuted to put the channel axes first for the Torch call, and
    the Torch output is permuted back before comparing.
    """

    def run_conv_transpose2D(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1),
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        # Options shared verbatim by the MLX call and the Torch reference.
        conv_args = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        with self.subTest(
            dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_args
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            iH, iW = idim
            kH, kW = kdim
            scale = 1.0 / math.sqrt(kH * kW * C)

            in_shape = (N, iH, iW, C)
            wt_shape = (O, kH, kW, int(C / groups))
            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
            wt_np = np.random.normal(0.0, 1.0, wt_shape).astype(np_dtype)

            in_mx = mx.array(in_np)
            wt_mx = mx.array(wt_np)
            in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
            wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")

            out_mx = mx.conv_transpose2d(in_mx, wt_mx, **conv_args)
            out_pt = torch.conv_transpose2d(in_pt, wt_pt, **conv_args)
            out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

    channel_cases = (
        (1, 1, 1),
        (1, 6, 1),
        (1, 1, 6),
        (4, 32, 64),
    )
    geometry_cases = (
        ((1, 1), (1, 1), (1, 1), (0, 0)),
        ((3, 3), (3, 1), (1, 1), (0, 0)),
        ((31, 31), (5, 5), (5, 5), (2, 2)),
    )

    for dtype in ("float32",):
        for N, C, O in channel_cases:
            for idim, kdim, stride, padding in geometry_cases:
                run_conv_transpose2D(
                    N, C, O, idim, kdim, stride, padding, dtype=dtype
                )

        # Groups tests
        N, C, O = (4, 32, 64)
        for idim, kdim, stride, padding in geometry_cases:
            for group in (1,):
                run_conv_transpose2D(
                    N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_2D_grad(self):
    """Compare ``mx.conv_transpose2d`` input/weight VJPs with Torch autograd.

    Torch produces both the cotangent (the gradient retained on its output)
    and the reference gradients; the same cotangent is then fed to ``mx.vjp``.
    """

    def run_conv_transpose2D_grad(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1),
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            idim=idim,
            kdim=kdim,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            iH, iW = idim
            kH, kW = kdim
            scale = 1.0 / math.sqrt(kH * kW * C * O)
            in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
            wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                True
            )
            wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                True
            )

            # Pass ``groups`` to the reference too, so both frameworks are
            # always called with the same configuration.
            out_pt = F.conv_transpose2d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )

            # use torch to compute ct
            out_pt.retain_grad()
            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

            pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
            pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()

            ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

            def f(a, b):
                return mx.conv_transpose2d(
                    a,
                    b,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

            _, outs_mx = mx.vjp(
                f,
                [in_mx, wt_mx],
                [ct_mx],
            )
            mx_grad_in, mx_grad_wt = outs_mx

            self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
            self.assertEqual(in_mx.shape, mx_grad_in.shape)
            self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

            self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
            self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

    for dtype in ("float32",):
        for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
            for idim, kdim, stride, padding, dilation in (
                ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
            ):
                run_conv_transpose2D_grad(
                    N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_3D(self):
    """Compare ``mx.conv_transpose3d`` outputs with ``torch.conv_transpose3d``.

    Inputs are built as ``(N, iD, iH, iW, C)`` and weights as
    ``(O, kD, kH, kW, C)``; both are permuted to put the channel axes first
    for the Torch call, and the Torch output is permuted back before comparing.
    """

    def run_conv_transpose3D(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1, 1),
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        # Options shared verbatim by the MLX call and the Torch reference.
        conv_args = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        with self.subTest(
            dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_args
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            iD, iH, iW = idim
            kD, kH, kW = kdim
            scale = 1.0 / math.sqrt(kD * kH * kW * C * O)

            in_shape = (N, iD, iH, iW, C)
            wt_shape = (O, kD, kH, kW, C)
            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
            wt_np = np.random.normal(0.0, 1.0, wt_shape).astype(np_dtype)

            in_mx = mx.array(in_np)
            wt_mx = mx.array(wt_np)
            in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
            wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

            out_mx = mx.conv_transpose3d(in_mx, wt_mx, **conv_args)
            out_pt = torch.conv_transpose3d(in_pt, wt_pt, **conv_args)
            out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

    channel_cases = (
        (1, 1, 1),
        (1, 6, 1),
        (1, 1, 6),
        (2, 8, 16),
    )
    geometry_cases = (
        ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
        ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
        ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
    )

    for dtype in ("float32",):
        for N, C, O in channel_cases:
            for idim, kdim, stride, padding in geometry_cases:
                run_conv_transpose3D(
                    N, C, O, idim, kdim, stride, padding, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_3D_grad(self):
    """Compare ``mx.conv_transpose3d`` input/weight VJPs with Torch autograd.

    Torch produces both the cotangent (the gradient retained on its output)
    and the reference gradients; the same cotangent is then fed to ``mx.vjp``.
    """

    def run_conv_transpose3D_grad(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1, 1),
        groups=1,
        dtype="float32",
        atol=1e-4,
    ):
        # Options shared verbatim by the MLX call and the Torch reference.
        conv_args = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        with self.subTest(
            dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_args
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            iD, iH, iW = idim
            kD, kH, kW = kdim
            scale = 1.0 / math.sqrt(kD * kH * kW * C * O)

            in_shape = (N, iD, iH, iW, C)
            wt_shape = (O, kD, kH, kW, C)
            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
            wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)

            in_mx = mx.array(in_np)
            wt_mx = mx.array(wt_np)
            in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
            in_pt = in_pt.requires_grad_(True)
            wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
            wt_pt = wt_pt.requires_grad_(True)

            out_pt = F.conv_transpose3d(in_pt, wt_pt, **conv_args)

            # use torch to compute ct
            out_pt.retain_grad()
            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

            pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
            pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()
            ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))

            def conv_fn(a, b):
                return mx.conv_transpose3d(a, b, **conv_args)

            _, (mx_grad_in, mx_grad_wt) = mx.vjp(conv_fn, [in_mx, wt_mx], [ct_mx])

            self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
            self.assertEqual(in_mx.shape, mx_grad_in.shape)
            self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

            self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
            self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

    channel_cases = ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16))
    geometry_cases = (
        ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
        ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
        ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
        ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
        ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
        ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
    )

    for dtype in ("float32",):
        for N, C, O in channel_cases:
            for idim, kdim, stride, padding, dilation in geometry_cases:
                run_conv_transpose3D_grad(
                    N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                )
if __name__ == "__main__":
unittest.main()