Transposed Convolution (#1245)

* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Max-Heinrich Laves 2024-09-07 04:52:38 +02:00 committed by GitHub
parent ba3e913c7a
commit efeb9c0f02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1337 additions and 45 deletions

View File

@ -18,6 +18,7 @@ MLX was developed with contributions from the following individuals:
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation - Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors"> <a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
    """Time repeated calls of ``f(a, b)`` and return the elapsed seconds.

    ``f`` is first called ``N_warmup`` times without being timed, then
    ``N_iter_bench`` times inside the timed region.
    """
    # Untimed warm-up so one-off setup cost does not pollute the measurement.
    for _ in range(N_warmup):
        f(a, b)
    # Drain any queued MPS work before the clock starts.
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    stop_ns = time.perf_counter_ns()

    # perf_counter_ns reports nanoseconds; convert to seconds.
    return (stop_ns - start_ns) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build an MLX benchmark kernel bound to the given conv settings.

    The returned callable takes ``(a, b)``, runs ``mx.conv_transpose2d``
    ``N_iter_func`` times, evaluates all results together and returns them
    as a list.
    """

    def mx_conv_transpose_2D(a, b):
        outputs = [
            mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # MLX is lazy: force evaluation of every queued result at once.
        mx.eval(outputs)
        return outputs

    return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a PyTorch benchmark kernel bound to the given conv settings.

    The returned callable takes ``(a, b)``, runs ``torch.conv_transpose2d``
    ``N_iter_func`` times with gradients disabled, synchronizes the MPS
    device and returns the results as a list.
    """

    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        outputs = [
            torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # Wait for the queued MPS kernels so the caller times real work.
        torch.mps.synchronize()
        return outputs

    return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Benchmark one transposed-conv configuration on MLX and PyTorch (MPS).

    Random inputs are generated once with NumPy and shared by both
    frameworks. MLX receives them in the generated layout; for PyTorch the
    channel axis is moved to the front of the spatial axes via ``transpose``.
    After timing, one result from each framework is compared and a message
    is printed when they disagree beyond a dtype-dependent tolerance.

    Args:
        N, H, W, C: Batch size, input height, input width, input channels.
        kH, kW: Kernel height and width.
        O: Size of the leading weight axis.
        strides, padding, groups: Forwarded to both ``conv_transpose2d`` calls.
        np_dtype: NumPy dtype used to generate the inputs.

    Returns:
        tuple: ``(time_mlx, time_torch)`` in seconds, as returned by ``bench``.
    """
    # Scale the weights by the kernel fan-in. The kernel area is kH * kW
    # (the previous kH * kH was only right for square kernels).
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C // groups)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # Reorder axes for PyTorch: input (0, 3, 1, 2), weight (3, 0, 1, 2).
    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    # Make sure the host-to-device copies are finished before timing starts.
    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check: the PyTorch reference is computed on the CPU and
    # permuted back to channels-last before comparing against MLX.
    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row: (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        # The dtype name maps directly onto a NumPy attribute.
        np_dtype = getattr(np, dtype)
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for shape in shapes:
            N, H, W, C, kH, kW, O, strides, padding, groups = shape
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Positive means MLX was faster than PyTorch, negative slower.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

View File

@ -16,6 +16,9 @@ Layers
Conv1d Conv1d
Conv2d Conv2d
Conv3d Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout Dropout
Dropout2d Dropout2d
Dropout3d Dropout3d

View File

@ -45,6 +45,9 @@ Operations
conv1d conv1d
conv2d conv2d
conv3d conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general conv_general
cos cos
cosh cosh

View File

@ -1125,7 +1125,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
else { else {
std::ostringstream msg; std::ostringstream msg;
msg << "[Convolution::eval] Convolution currently only supports" msg << "[Convolution::eval] Convolution currently only supports"
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2 << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
<< " spatial dimensions"; << " spatial dimensions";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@ -911,7 +911,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
// Throw error // Throw error
else { else {
throw std::invalid_argument( throw std::invalid_argument(
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions."); "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
} }
// Clear copies // Clear copies

View File

@ -3298,6 +3298,93 @@ array conv3d(
s); s);
} }
// Helper function for transposed convolutions.
//
// Expresses a transposed convolution as a call to conv_general: the kernel
// stride becomes 1, the requested `stride` is applied as input dilation, the
// kernel is flipped, and the low/high padding is derived per spatial axis
// from the (dilated) kernel extent and the requested `padding`.
//
// `stride`, `padding` and `dilation` are indexed in lockstep, one entry per
// spatial axis; spatial axis i is read from dimension i + 1 of both `input`
// and `weight`.
array conv_transpose_general(
    const array& input,
    const array& weight,
    std::vector<int> stride,
    std::vector<int> padding,
    std::vector<int> dilation,
    int groups,
    StreamOrDevice s) {
  std::vector<int> padding_lo(padding.size());
  std::vector<int> padding_hi(padding.size());
  const int n_spatial = static_cast<int>(padding.size());
  for (int i = 0; i < n_spatial; ++i) {
    // Extent of the kernel along this axis once dilation is applied.
    int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1);
    padding_lo[i] = wt_size - padding[i] - 1;

    // Output extent of the transposed convolution along this axis.
    int conv_output_shape = (input.shape(i + 1) - 1) * stride[i] -
        2 * padding[i] + dilation[i] * (weight.shape(i + 1) - 1) + 1;

    int in_size = 1 + (conv_output_shape - 1);
    // Extent of the input after it has been dilated by the stride.
    int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
    padding_hi[i] = in_size - out_size + padding[i];
  }

  // Build the unit kernel stride before the call. Function arguments are
  // indeterminately sequenced, so reading stride.size() in the same argument
  // list that moves from `stride` could observe a moved-from vector.
  std::vector<int> unit_stride(stride.size(), 1);

  return conv_general(
      /* const array& input = */ input,
      /* const array& weight = */ weight,
      /* std::vector<int> stride = */ std::move(unit_stride),
      /* std::vector<int> padding_lo = */ std::move(padding_lo),
      /* std::vector<int> padding_hi = */ std::move(padding_hi),
      /* std::vector<int> kernel_dilation = */ std::move(dilation),
      /* std::vector<int> input_dilation = */ std::move(stride),
      /* int groups = */ groups,
      /* bool flip = */ true,
      s);
}
/** 1D transposed convolution with a filter */
array conv_transpose1d(
    const array& in_,
    const array& wt_,
    int stride /* = 1 */,
    int padding /* = 0 */,
    int dilation /* = 1 */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  // Lift the scalar settings into one-element vectors for the shared helper.
  std::vector<int> strides{stride};
  std::vector<int> pads{padding};
  std::vector<int> dilations{dilation};
  return conv_transpose_general(
      in_,
      wt_,
      std::move(strides),
      std::move(pads),
      std::move(dilations),
      groups,
      s);
}
/** 2D transposed convolution with a filter */
array conv_transpose2d(
    const array& in_,
    const array& wt_,
    const std::pair<int, int>& stride /* = {1, 1} */,
    const std::pair<int, int>& padding /* = {0, 0} */,
    const std::pair<int, int>& dilation /* = {1, 1} */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  // Unpack each pair into a two-element vector for the shared helper.
  std::vector<int> strides{stride.first, stride.second};
  std::vector<int> pads{padding.first, padding.second};
  std::vector<int> dilations{dilation.first, dilation.second};
  return conv_transpose_general(
      in_,
      wt_,
      std::move(strides),
      std::move(pads),
      std::move(dilations),
      groups,
      s);
}
/** 3D transposed convolution with a filter */
array conv_transpose3d(
    const array& in_,
    const array& wt_,
    const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
    const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
    const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  // Unpack each tuple into a three-element vector for the shared helper.
  const auto& [stride_0, stride_1, stride_2] = stride;
  const auto& [pad_0, pad_1, pad_2] = padding;
  const auto& [dil_0, dil_1, dil_2] = dilation;

  std::vector<int> strides{stride_0, stride_1, stride_2};
  std::vector<int> pads{pad_0, pad_1, pad_2};
  std::vector<int> dilations{dil_0, dil_1, dil_2};
  return conv_transpose_general(
      in_,
      wt_,
      std::move(strides),
      std::move(pads),
      std::move(dilations),
      groups,
      s);
}
/** General convolution with a filter */ /** General convolution with a filter */
array conv_general( array conv_general(
array in, array in,

View File

@ -1247,6 +1247,36 @@ array conv3d(
int groups = 1, int groups = 1,
StreamOrDevice s = {}); StreamOrDevice s = {});
/**
 * 1D transposed convolution with a filter.
 *
 * `stride`, `padding` and `dilation` each apply to the single spatial axis.
 * All of them, `groups` and the stream/device `s` have defaults, so only
 * `input` and `weight` are required.
 *
 * NOTE(review): the Python binding documents `input` as (N, H, C_in) and
 * `weight` as (C_out, H, C_in) -- confirm against the implementation.
 */
array conv_transpose1d(
const array& input,
const array& weight,
int stride = 1,
int padding = 0,
int dilation = 1,
int groups = 1,
StreamOrDevice s = {});
/**
 * 2D transposed convolution with a filter.
 *
 * `stride`, `padding` and `dilation` are pairs holding one value per spatial
 * axis. All of them, `groups` and the stream/device `s` have defaults, so
 * only `input` and `weight` are required.
 *
 * NOTE(review): the Python binding documents `input` as (N, H, W, C_in) and
 * `weight` as (C_out, H, W, C_in) -- confirm against the implementation.
 */
array conv_transpose2d(
const array& input,
const array& weight,
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
int groups = 1,
StreamOrDevice s = {});
/**
 * 3D transposed convolution with a filter.
 *
 * `stride`, `padding` and `dilation` are 3-tuples holding one value per
 * spatial axis. All of them, `groups` and the stream/device `s` have
 * defaults, so only `input` and `weight` are required.
 *
 * NOTE(review): the Python binding documents `input` as (N, D, H, W, C_in)
 * and `weight` as (C_out, D, H, W, C_in) -- confirm against the
 * implementation.
 */
array conv_transpose3d(
const array& input,
const array& weight,
const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1},
int groups = 1,
StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/ /** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul( array quantized_matmul(
const array& x, const array& x,

View File

@ -952,8 +952,8 @@ std::vector<array> Convolution::vjp(
/* const array& input = */ cotan, /* const array& input = */ cotan,
/* const array& weight = */ wt_trans, /* const array& weight = */ wt_trans,
/* std::vector<int> stride = */ input_dilation_, /* std::vector<int> stride = */ input_dilation_,
/* std::vector<int> padding_lo = */ padding_lo_, /* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi_, /* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_dilation_, /* std::vector<int> kernel_dilation = */ kernel_dilation_,
/* std::vector<int> input_dilation = */ kernel_strides_, /* std::vector<int> input_dilation = */ kernel_strides_,
/* int groups = */ 1, /* int groups = */ 1,
@ -990,10 +990,34 @@ std::vector<array> Convolution::vjp(
no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1); no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
} }
if (no_dilation) { if (no_dilation && !flip_) {
auto grad = conv_weight_backward_patches( auto grad = conv_weight_backward_patches(
in, wt, cotan, kernel_strides_, padding_, stream()); in, wt, cotan, kernel_strides_, padding_, stream());
grads.push_back(grad); grads.push_back(grad);
} else {
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = swapaxes(in, 0, -1, stream());
if (flip_) {
auto padding = padding_;
for (int i = 0; i < padding.size(); i++) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding[i] = wt_size - padding_[i] - 1;
}
auto grad_trans = conv_general(
/* const array& input = */ cotan_trans,
/* const array& weight = */ in_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding,
/* std::vector<int> padding_hi = */ padding,
/* std::vector<int> kernel_dilation = */ input_dilation_,
/* std::vector<int> input_dilation = */ kernel_strides_,
/* int groups = */ 1,
/* bool flip = */ false,
stream());
auto grad = swapaxes(grad_trans, 0, -1, stream());
grads.push_back(grad_trans);
} else { } else {
std::vector<int> padding_lo = padding_; std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_; std::vector<int> padding_hi = padding_;
@ -1016,13 +1040,14 @@ std::vector<array> Convolution::vjp(
/* std::vector<int> kernel_dilation = */ kernel_strides_, /* std::vector<int> kernel_dilation = */ kernel_strides_,
/* std::vector<int> input_dilation = */ input_dilation_, /* std::vector<int> input_dilation = */ input_dilation_,
/* int groups = */ 1, /* int groups = */ 1,
/* bool flip = */ flip_, /* bool flip = */ false,
stream()); stream());
auto grad = swapaxes(grad_trans, 0, -1, stream()); auto grad = swapaxes(grad_trans, 0, -1, stream());
grads.push_back(grad); grads.push_back(grad);
} }
} }
} }
}
return grads; return grads;
} }

View File

@ -55,6 +55,11 @@ from mlx.nn.layers.activations import (
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
from mlx.nn.layers.convolution_transpose import (
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear from mlx.nn.layers.linear import Bilinear, Identity, Linear

View File

@ -21,9 +21,9 @@ class Conv1d(Module):
out_channels (int): The number of output channels out_channels (int): The number of output channels
kernel_size (int): The size of the convolution filters kernel_size (int): The size of the convolution filters
stride (int, optional): The stride when applying the filter. stride (int, optional): The stride when applying the filter.
Default: 1. Default: ``1``.
padding (int, optional): How many positions to 0-pad the input with. padding (int, optional): How many positions to 0-pad the input with.
Default: 0. Default: ``0``.
dilation (int, optional): The dilation of the convolution. dilation (int, optional): The dilation of the convolution.
bias (bool, optional): If ``True`` add a learnable bias to the output. bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True`` Default: ``True``
@ -84,9 +84,9 @@ class Conv2d(Module):
out_channels (int): The number of output channels. out_channels (int): The number of output channels.
kernel_size (int or tuple): The size of the convolution filters. kernel_size (int or tuple): The size of the convolution filters.
stride (int or tuple, optional): The size of the stride when stride (int or tuple, optional): The size of the stride when
applying the filter. Default: 1. applying the filter. Default: ``1``.
padding (int or tuple, optional): How many positions to 0-pad padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: 0. the input with. Default: ``0``.
dilation (int or tuple, optional): The dilation of the convolution. dilation (int or tuple, optional): The dilation of the convolution.
bias (bool, optional): If ``True`` add a learnable bias to the bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True`` output. Default: ``True``

View File

@ -0,0 +1,206 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Union
import mlx.core as mx
from mlx.nn.layers.base import Module
class ConvTranspose1d(Module):
    """Applies a 1-dimensional transposed convolution over the multi-channel input sequence.

    The channels are expected to be last i.e. the input shape should be ``NLC`` where:

    * ``N`` is the batch dimension
    * ``L`` is the sequence length
    * ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels
        out_channels (int): The number of output channels
        kernel_size (int): The size of the convolution filters
        stride (int, optional): The stride when applying the filter.
            Default: ``1``.
        padding (int, optional): How many positions to 0-pad the input with.
            Default: ``0``.
        dilation (int, optional): The dilation of the convolution.
            Default: ``1``.
        bias (bool, optional): If ``True`` add a learnable bias to the output.
            Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        bias: bool = True,
    ):
        super().__init__()

        # Uniform init bounded by 1 / sqrt(fan_in), fan_in = C_in * kernel.
        bound = math.sqrt(1 / (in_channels * kernel_size))
        weight_shape = (out_channels, kernel_size, in_channels)
        self.weight = mx.random.uniform(low=-bound, high=bound, shape=weight_shape)
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.dilation = dilation
        self.stride = stride

    def _extra_repr(self):
        w_shape = self.weight.shape
        in_channels, out_channels, kernel = w_shape[-1], w_shape[0], w_shape[1]
        has_bias = "bias" in self
        return (
            f"{in_channels}, {out_channels}, "
            f"kernel_size={kernel}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"bias={has_bias}"
        )

    def __call__(self, x):
        out = mx.conv_transpose1d(
            x, self.weight, self.stride, self.padding, self.dilation
        )
        # The bias is only present when the layer was built with bias=True.
        return out + self.bias if "bias" in self else out
class ConvTranspose2d(Module):
    """Applies a 2-dimensional transposed convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:

    * ``N`` is the batch dimension
    * ``H`` is the input image height
    * ``W`` is the input image width
    * ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: ``1``.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: ``0``.
        dilation (int or tuple, optional): The dilation of the convolution.
            Default: ``1``.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        dilation: Union[int, tuple] = 1,
        bias: bool = True,
    ):
        super().__init__()

        # Broadcast scalar settings to one value per spatial axis. The
        # dilation is stored as given.
        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        # Uniform init bounded by 1 / sqrt(fan_in).
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride
        self.dilation = dilation

    def _extra_repr(self):
        # The weight is (out_channels, kH, kW, in_channels), so both kernel
        # dims live in shape[1:3]; the previous [1:2] slice dropped kW.
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv_transpose2d(
            x, self.weight, self.stride, self.padding, self.dilation
        )
        # The bias is only present when the layer was built with bias=True.
        if "bias" in self:
            y = y + self.bias
        return y
class ConvTranspose3d(Module):
    """Applies a 3-dimensional transposed convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:

    * ``N`` is the batch dimension
    * ``D`` is the input image depth
    * ``H`` is the input image height
    * ``W`` is the input image width
    * ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: ``1``.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: ``0``.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        bias: bool = True,
    ):
        super().__init__()

        # Broadcast scalar settings to one value per spatial axis.
        kernel_size, stride, padding = map(
            lambda x: (x, x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        # Uniform init bounded by 1 / sqrt(fan_in).
        scale = math.sqrt(
            1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
        )
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        # The weight is (out_channels, kD, kH, kW, in_channels), so the three
        # kernel dims live in shape[1:4]; the previous [1:3] slice dropped kW.
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv_transpose3d(x, self.weight, self.stride, self.padding)
        # The bias is only present when the layer was built with bias=True.
        if "bias" in self:
            y = y + self.bias
        return y

View File

@ -3238,12 +3238,12 @@ void init_ops(nb::module_& m) {
1D convolution over an input with several channels 1D convolution over an input with several channels
Args: Args:
input (array): input array of shape (``N``, ``H``, ``C_in``) input (array): Input array of shape ``(N, H, C_in)``.
weight (array): weight array of shape (``C_out``, ``H``, ``C_in``) weight (array): Weight array of shape ``(C_out, H, C_in)``.
stride (int, optional): kernel stride. Default: ``1``. stride (int, optional): Kernel stride. Default: ``1``.
padding (int, optional): input padding. Default: ``0``. padding (int, optional): Input padding. Default: ``0``.
dilation (int, optional): kernel dilation. Default: ``1``. dilation (int, optional): Kernel dilation. Default: ``1``.
groups (int, optional): input feature groups. Default: ``1``. groups (int, optional): Input feature groups. Default: ``1``.
Returns: Returns:
array: The convolved array. array: The convolved array.
@ -3296,8 +3296,8 @@ void init_ops(nb::module_& m) {
2D convolution over an input with several channels 2D convolution over an input with several channels
Args: Args:
input (array): input array of shape ``(N, H, W, C_in)`` input (array): Input array of shape ``(N, H, W, C_in)``.
weight (array): weight array of shape ``(C_out, H, W, C_in)`` weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel strides. All spatial dimensions get the same stride if kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``. only one number is specified. Default: ``1``.
@ -3368,8 +3368,173 @@ void init_ops(nb::module_& m) {
Note: Only the default ``groups=1`` is currently supported. Note: Only the default ``groups=1`` is currently supported.
Args: Args:
input (array): input array of shape ``(N, D, H, W, C_in)`` input (array): Input array of shape ``(N, D, H, W, C_in)``.
weight (array): weight array of shape ``(C_out, D, H, W, C_in)`` weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
symmetric input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"conv_transpose1d",
&conv_transpose1d,
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
1D transposed convolution over an input with several channels
Args:
input (array): Input array of shape ``(N, H, C_in)``.
weight (array): Weight array of shape ``(C_out, H, C_in)``.
stride (int, optional): Kernel stride. Default: ``1``.
padding (int, optional): Input padding. Default: ``0``.
dilation (int, optional): Kernel dilation. Default: ``1``.
groups (int, optional): Input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"conv_transpose2d",
[](const array& input,
const array& weight,
const std::variant<int, std::pair<int, int>>& stride,
const std::variant<int, std::pair<int, int>>& padding,
const std::variant<int, std::pair<int, int>>& dilation,
int groups,
StreamOrDevice s) {
std::pair<int, int> stride_pair{1, 1};
std::pair<int, int> padding_pair{0, 0};
std::pair<int, int> dilation_pair{1, 1};
if (auto pv = std::get_if<int>(&stride); pv) {
stride_pair = std::pair<int, int>{*pv, *pv};
} else {
stride_pair = std::get<std::pair<int, int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_pair = std::pair<int, int>{*pv, *pv};
} else {
padding_pair = std::get<std::pair<int, int>>(padding);
}
if (auto pv = std::get_if<int>(&dilation); pv) {
dilation_pair = std::pair<int, int>{*pv, *pv};
} else {
dilation_pair = std::get<std::pair<int, int>>(dilation);
}
return conv_transpose2d(
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
},
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
2D transposed convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, H, W, C_in)``.
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
symmetric input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"conv_transpose3d",
[](const array& input,
const array& weight,
const std::variant<int, std::tuple<int, int, int>>& stride,
const std::variant<int, std::tuple<int, int, int>>& padding,
const std::variant<int, std::tuple<int, int, int>>& dilation,
int groups,
StreamOrDevice s) {
std::tuple<int, int, int> stride_tuple{1, 1, 1};
std::tuple<int, int, int> padding_tuple{0, 0, 0};
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
if (auto pv = std::get_if<int>(&stride); pv) {
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
}
if (auto pv = std::get_if<int>(&dilation); pv) {
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
}
return conv_transpose3d(
input,
weight,
stride_tuple,
padding_tuple,
dilation_tuple,
groups,
s);
},
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
3D transposed convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, D, H, W, C_in)``.
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel strides. All spatial dimensions get the same stride if kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``. only one number is specified. Default: ``1``.
@ -3465,8 +3630,8 @@ void init_ops(nb::module_& m) {
General convolution over an input with several channels General convolution over an input with several channels
Args: Args:
input (array): Input array of shape ``(N, ..., C_in)`` input (array): Input array of shape ``(N, ..., C_in)``.
weight (array): Weight array of shape ``(C_out, ..., C_in)`` weight (array): Weight array of shape ``(C_out, ..., C_in)``.
stride (int or list(int), optional): :obj:`list` with kernel strides. stride (int or list(int), optional): :obj:`list` with kernel strides.
All spatial dimensions get the same stride if All spatial dimensions get the same stride if
only one number is specified. Default: ``1``. only one number is specified. Default: ``1``.

View File

@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
flip=flip, flip=flip,
) )
def test_conv_general_flip_grad(self):
    """Check the weight gradient of a flipped ``conv_general`` by hand.

    For input dilations 1 and 2, the VJP with respect to a 2x2 kernel is
    compared against the four tap sums computed directly from the
    cotangent and the input.
    """
    for s in (1, 2):
        w = mx.random.normal(shape=(1, 2, 2, 1))
        x = mx.random.normal(shape=(1, 2, 2, 1))

        def conv_t(w):
            return mx.conv_general(
                x,
                w,
                stride=1,
                padding=(1, 1),
                kernel_dilation=1,
                input_dilation=s,
                flip=True,
            )

        cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))

        # vjp returns (outputs, vjps); take the gradient w.r.t. the weight.
        _, vjps = mx.vjp(conv_t, (w,), (cotan,))
        dw = vjps[0].squeeze()

        x2d = x.squeeze()
        cotan2d = cotan.squeeze()

        # Each kernel tap sees the cotangent through a strided window:
        # tap 0 starts at the first element, tap 1 at the second.
        taps = (slice(None, -1, s), slice(1, None, s))
        expected = mx.array(
            [[(cotan2d[rows, cols] * x2d).sum() for cols in taps] for rows in taps]
        )
        self.assertTrue(mx.allclose(dw, expected))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -0,0 +1,601 @@
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
from itertools import permutations
import mlx.core as mx
import mlx_tests
import numpy as np
try:
import torch
import torch.nn.functional as F
has_torch = True
except ImportError as e:
has_torch = False
class TestConvTranspose(mlx_tests.MLXTestCase):
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_1D(self):
    """Check ``mx.conv_transpose1d`` outputs against ``torch.conv_transpose1d``.

    Three groups of cases are exercised: a grid of batch/channel sizes and
    (length, kernel, stride, padding) settings, a ``groups`` sweep (which
    currently only contains ``groups=1``), and operands that have been
    transposed before being handed to the op.
    """

    def run_conv_transpose_1D(
        N,
        C,
        O,
        iH,
        kH,
        stride,
        padding,
        dilation=1,
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        """Run one configuration as a sub-test and compare both backends."""
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            iH=iH,
            kH=kH,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
            # Integer division: the channel count per group is a shape entry.
            wt_np = np.random.normal(0, 1.0 / C, (O, kH, C // groups)).astype(
                np_dtype
            )

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            # Reorder the axes for torch: input (N, iH, C) -> (N, C, iH) and
            # weight (O, kH, C) -> (C, O, kH).
            in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
            wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

            out_mx = mx.conv_transpose1d(
                in_mx,
                wt_mx,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            out_pt = torch.conv_transpose1d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            # Move torch's channel axis last so both outputs share a layout.
            out_pt = torch.transpose(out_pt, 2, 1)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

    for dtype in ("float32",):
        for N, C, O in (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (4, 32, 64),
        ):
            for iH, kH, stride, padding in (
                (1, 1, 1, 0),
                (3, 3, 1, 0),
                (31, 5, 5, 2),
            ):
                run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

        # Groups tests
        N, C, O = (4, 32, 64)
        for iH, kH, stride, padding in (
            (1, 1, 1, 0),
            (3, 3, 1, 0),
            (31, 5, 5, 2),
        ):
            for group in (1,):
                run_conv_transpose_1D(
                    N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
                )

    # Strided inputs tests
    for tpose_in, tpose_wt in (
        ((0, 2, 1), (0, 1, 2)),
        ((0, 2, 1), (0, 2, 1)),
    ):
        with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
            in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
            wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_mx_t = mx.transpose(in_mx, tpose_in)
            wt_mx_t = mx.transpose(wt_mx, tpose_wt)
            out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

            # Apply the same permutation on the NumPy side, then reorder the
            # axes for torch exactly as in the helper above.
            in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
            wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))
            out_pt = torch.conv_transpose1d(in_pt, wt_pt)
            out_pt = torch.transpose(out_pt, 2, 1)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_1D_grad(self):
    """Check ``mx.conv_transpose1d`` input/weight VJPs against torch autograd.

    Torch produces both the cotangent (the gradient of a scalar loss with
    respect to the convolution output) and the reference gradients; the same
    cotangent is then pushed through ``mx.vjp``.
    """

    def run_conv_transpose1D_grad(
        N,
        C,
        O,
        iH,
        kH,
        stride,
        padding,
        dilation=1,
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        """Run one configuration as a sub-test and compare both gradients."""
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            iH=iH,
            kH=kH,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            # torch.randn_like below draws from torch's own generator, which
            # np.random.seed does not touch; seed it as well so the cotangent
            # is reproducible from run to run.
            torch.manual_seed(0)

            in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
            wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
            wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

            # Forward ``groups`` to the reference too, so both backends are
            # always asked for the same convolution.
            out_pt = F.conv_transpose1d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )

            # use torch to compute ct
            out_pt.retain_grad()
            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

            # Permute torch gradients back to the (N, iH, C) / (O, kH, C)
            # layouts used for the MLX operands.
            pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
            pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()

            ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

            def f(a, b):
                return mx.conv_transpose1d(
                    a,
                    b,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

            _, outs_mx = mx.vjp(
                f,
                [in_mx, wt_mx],
                [ct_mx],
            )
            mx_grad_in, mx_grad_wt = outs_mx

            self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
            self.assertEqual(in_mx.shape, mx_grad_in.shape)
            self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

            self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
            self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

    for dtype in ("float32",):
        for N, C, O in (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (4, 32, 64),
        ):
            for iH, kH, stride, padding in (
                (1, 1, 1, 0),
                (3, 3, 1, 0),
                (31, 5, 5, 2),
            ):
                run_conv_transpose1D_grad(
                    N, C, O, iH, kH, stride, padding, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_2D(self):
    """Check ``mx.conv_transpose2d`` outputs against ``torch.conv_transpose2d``.

    Exercises a grid of batch/channel sizes and (input size, kernel size,
    stride, padding) settings, followed by a ``groups`` sweep that currently
    only contains ``groups=1``.
    """

    def run_conv_transpose2D(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1),
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        """Run one configuration as a sub-test and compare both backends."""
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            idim=idim,
            kdim=kdim,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            iH, iW = idim
            kH, kW = kdim
            scale = 1.0 / math.sqrt(kH * kW * C)
            in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
            # Integer division: the channel count per group is a shape entry.
            wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C // groups)).astype(
                np_dtype
            )

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            # Reorder the axes for torch: input (N, iH, iW, C) -> (N, C, iH, iW)
            # and weight (O, kH, kW, C) -> (C, O, kH, kW). Tensors built by
            # torch.from_numpy already live on the CPU.
            in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
            wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))

            out_mx = mx.conv_transpose2d(
                in_mx,
                wt_mx,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            out_pt = torch.conv_transpose2d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )
            # Move torch's channel axis last so both outputs share a layout.
            out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

            self.assertEqual(out_pt.shape, out_mx.shape)
            self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

    # (idim, kdim, stride, padding) settings shared by both sweeps below.
    spatial_cases = (
        ((1, 1), (1, 1), (1, 1), (0, 0)),
        ((3, 3), (3, 1), (1, 1), (0, 0)),
        ((31, 31), (5, 5), (5, 5), (2, 2)),
    )

    for dtype in ("float32",):
        for N, C, O in (
            (1, 1, 1),
            (1, 6, 1),
            (1, 1, 6),
            (4, 32, 64),
        ):
            for idim, kdim, stride, padding in spatial_cases:
                run_conv_transpose2D(
                    N, C, O, idim, kdim, stride, padding, dtype=dtype
                )

        # Groups tests
        N, C, O = (4, 32, 64)
        for idim, kdim, stride, padding in spatial_cases:
            for group in (1,):
                run_conv_transpose2D(
                    N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_2D_grad(self):
    """Check ``mx.conv_transpose2d`` input/weight VJPs against torch autograd.

    Torch produces both the cotangent (the gradient of a scalar loss with
    respect to the convolution output) and the reference gradients; the same
    cotangent is then pushed through ``mx.vjp``.
    """

    def run_conv_transpose2D_grad(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1),
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        """Run one configuration as a sub-test and compare both gradients."""
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            idim=idim,
            kdim=kdim,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            # torch.randn_like below draws from torch's own generator, which
            # np.random.seed does not touch; seed it as well so the cotangent
            # is reproducible from run to run.
            torch.manual_seed(0)

            iH, iW = idim
            kH, kW = kdim
            scale = 1.0 / math.sqrt(kH * kW * C * O)
            in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
            wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                True
            )
            wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                True
            )

            # Forward ``groups`` to the reference too, so both backends are
            # always asked for the same convolution.
            out_pt = F.conv_transpose2d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )

            # use torch to compute ct
            out_pt.retain_grad()
            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

            # Permute torch gradients back to the channels-last layouts used
            # for the MLX operands.
            pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
            pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()

            ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

            def f(a, b):
                return mx.conv_transpose2d(
                    a,
                    b,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

            _, outs_mx = mx.vjp(
                f,
                [in_mx, wt_mx],
                [ct_mx],
            )
            mx_grad_in, mx_grad_wt = outs_mx

            self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
            self.assertEqual(in_mx.shape, mx_grad_in.shape)
            self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

            self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
            self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

    for dtype in ("float32",):
        for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
            for idim, kdim, stride, padding, dilation in (
                ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
            ):
                run_conv_transpose2D_grad(
                    N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_3D(self):
    """Check ``mx.conv_transpose3d`` outputs against ``torch.conv_transpose3d``.

    Every combination of the batch/channel sizes and the (input size, kernel
    size, stride, padding) settings listed at the bottom is run as its own
    sub-test.
    """

    def run_conv_transpose3D(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1, 1),
        groups=1,
        dtype="float32",
        atol=1e-5,
    ):
        """Run one configuration as a sub-test and compare both backends."""
        # Keyword arguments handed unchanged to both implementations.
        conv_kwargs = dict(
            stride=stride, padding=padding, dilation=dilation, groups=groups
        )
        with self.subTest(
            dtype=dtype, N=N, C=C, O=O, idim=idim, kdim=kdim, **conv_kwargs
        ):
            elem_type = getattr(np, dtype)
            np.random.seed(0)

            depth, height, width = idim
            k_depth, k_height, k_width = kdim
            in_std = 1.0 / math.sqrt(k_depth * k_height * k_width * C * O)
            x_np = np.random.normal(
                0.0, in_std, (N, depth, height, width, C)
            ).astype(elem_type)
            w_np = np.random.normal(
                0.0, 1.0, (O, k_depth, k_height, k_width, C)
            ).astype(elem_type)

            got = mx.conv_transpose3d(mx.array(x_np), mx.array(w_np), **conv_kwargs)

            # Reorder the axes for torch: the last (channel) axis of the input
            # moves to position 1 and the last axis of the weight to position 0.
            x_pt = torch.from_numpy(x_np.transpose(0, 4, 1, 2, 3))
            w_pt = torch.from_numpy(w_np.transpose(4, 0, 1, 2, 3))
            want = torch.conv_transpose3d(x_pt, w_pt, **conv_kwargs)
            # Move torch's channel axis last so both outputs share a layout.
            want = torch.permute(want, (0, 2, 3, 4, 1)).numpy(force=True)

            self.assertEqual(want.shape, got.shape)
            self.assertTrue(np.allclose(want, got, atol=atol))

    channel_cases = (
        (1, 1, 1),
        (1, 6, 1),
        (1, 1, 6),
        (2, 8, 16),
    )
    spatial_cases = (
        ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
        ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
        ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
    )

    for dtype in ("float32",):
        for N, C, O in channel_cases:
            for idim, kdim, stride, padding in spatial_cases:
                run_conv_transpose3D(
                    N, C, O, idim, kdim, stride, padding, dtype=dtype
                )
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_3D_grad(self):
    """Check ``mx.conv_transpose3d`` input/weight VJPs against torch autograd.

    Torch produces both the cotangent (the gradient of a scalar loss with
    respect to the convolution output) and the reference gradients; the same
    cotangent is then pushed through ``mx.vjp``.
    """

    def run_conv_transpose3D_grad(
        N,
        C,
        O,
        idim,
        kdim,
        stride,
        padding,
        dilation=(1, 1, 1),
        groups=1,
        dtype="float32",
        atol=1e-4,
    ):
        """Run one configuration as a sub-test and compare both gradients."""
        with self.subTest(
            dtype=dtype,
            N=N,
            C=C,
            O=O,
            idim=idim,
            kdim=kdim,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        ):
            np_dtype = getattr(np, dtype)
            np.random.seed(0)
            # torch.randn_like below draws from torch's own generator, which
            # np.random.seed does not touch; seed it as well so the cotangent
            # is reproducible from run to run.
            torch.manual_seed(0)

            iD, iH, iW = idim
            kD, kH, kW = kdim
            scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
            in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                np_dtype
            )
            wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
                np_dtype
            )

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
            in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
                True
            )
            wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
                True
            )

            out_pt = F.conv_transpose3d(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )

            # use torch to compute ct
            out_pt.retain_grad()
            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()

            # Permute torch gradients back to the channels-last layouts used
            # for the MLX operands.
            pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
            pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()

            ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))

            def f(a, b):
                return mx.conv_transpose3d(
                    a,
                    b,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

            _, outs_mx = mx.vjp(
                f,
                [in_mx, wt_mx],
                [ct_mx],
            )
            mx_grad_in, mx_grad_wt = outs_mx

            self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
            self.assertEqual(in_mx.shape, mx_grad_in.shape)
            self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

            self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
            self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

    for dtype in ("float32",):
        for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
            for idim, kdim, stride, padding, dilation in (
                ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
                ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
                ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
                ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
            ):
                run_conv_transpose3D_grad(
                    N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                )
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()