From efeb9c0f02ba0f3ec734f458e85046506e624dde Mon Sep 17 00:00:00 2001 From: Max-Heinrich Laves <8014859+mlaves@users.noreply.github.com> Date: Sat, 7 Sep 2024 04:52:38 +0200 Subject: [PATCH] Transposed Convolution (#1245) * initial implementation for conv_transpose ran pre-commit implemented conv_transpose updated conv_general docstring updated conv_general docstring updated code comments removed commented run_conv_checks updated acknowledgments added missing entry to ops.rst added op to nn.layers resolved merge conflicts * removed ConvolutionTranspose primitive as suggested by reviewer removed ConvolutionTranspose primitive as suggested by reviewer * remove transpose flag, add another test --------- Co-authored-by: Awni Hannun --- ACKNOWLEDGMENTS.md | 1 + benchmarks/python/conv_transpose_bench.py | 135 ++++ docs/src/python/nn/layers.rst | 3 + docs/src/python/ops.rst | 3 + mlx/backend/common/conv.cpp | 2 +- mlx/backend/metal/conv.cpp | 2 +- mlx/ops.cpp | 87 +++ mlx/ops.h | 30 + mlx/primitives.cpp | 79 ++- python/mlx/nn/layers/__init__.py | 5 + python/mlx/nn/layers/convolution.py | 8 +- python/mlx/nn/layers/convolution_transpose.py | 206 ++++++ python/src/ops.cpp | 189 +++++- python/tests/test_conv.py | 31 + python/tests/test_conv_transpose.py | 601 ++++++++++++++++++ 15 files changed, 1337 insertions(+), 45 deletions(-) create mode 100644 benchmarks/python/conv_transpose_bench.py create mode 100644 python/mlx/nn/layers/convolution_transpose.py create mode 100644 python/tests/test_conv_transpose.py diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 265eca97f..f406c36bb 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -18,6 +18,7 @@ MLX was developed with contributions from the following individuals: - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. 
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. - Paul Paczuski: Improved stability of BCE loss calculation +- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. diff --git a/benchmarks/python/conv_transpose_bench.py b/benchmarks/python/conv_transpose_bench.py new file mode 100644 index 000000000..22522d8ed --- /dev/null +++ b/benchmarks/python/conv_transpose_bench.py @@ -0,0 +1,135 @@ +import argparse +import math +import os +import subprocess +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 10 +N_iter_bench = 100 +N_iter_func = 5 + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + torch.mps.synchronize() + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): + def mx_conv_transpose_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv_transpose2d( + a, b, stride=strides, padding=padding, groups=groups + ) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_transpose_2D + + +def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): + @torch.no_grad() + def pt_conv_transpose_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv_transpose2d( + a, b, stride=strides, padding=padding, groups=groups + ) + ys.append(y) + torch.mps.synchronize() + return ys + + return pt_conv_transpose_2D + + +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kH * kH * C) + a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") + b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps") + + 
torch.mps.synchronize() + + f_mx = make_mx_conv_transpose_2D(strides, padding, groups) + f_pt = make_pt_conv_transpose_2D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv_transpose2d( + a_mx, b_mx, stride=strides, padding=padding, groups=groups + ) + out_pt = torch.conv_transpose2d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run conv benchmarks") + + dtypes = ("float32",) + shapes = ( + (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), + (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), + ) + + for dtype in dtypes: + print( + "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" + ) + for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: + np_dtype = getattr(np, dtype) + time_mlx, 
time_torch = bench_shape( + N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index d06e009e8..77105ea35 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -16,6 +16,9 @@ Layers Conv1d Conv2d Conv3d + ConvTranspose1d + ConvTranspose2d + ConvTranspose3d Dropout Dropout2d Dropout3d diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index c2e70c824..65ed3006f 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -45,6 +45,9 @@ Operations conv1d conv2d conv3d + conv_transpose1d + conv_transpose2d + conv_transpose3d conv_general cos cosh diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index 79bc3c4a1..e60d3bc9c 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -1125,7 +1125,7 @@ void Convolution::eval(const std::vector& inputs, array& out) { else { std::ostringstream msg; msg << "[Convolution::eval] Convolution currently only supports" - << " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2 + << " 1D, 2D and 3D convolutions. 
Got inputs with " << in.ndim() - 2 << " spatial dimensions"; throw std::invalid_argument(msg.str()); } diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 395488ead..5fda6c3e8 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -911,7 +911,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { // Throw error else { throw std::invalid_argument( - "[Convolution::eval_gpu] Only supports 1D or 2D convolutions."); + "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions."); } // Clear copies diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9b9a0dcea..a27f7ca91 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3298,6 +3298,93 @@ array conv3d( s); } +// Helper function for transposed convolutions +array conv_transpose_general( + const array& input, + const array& weight, + std::vector stride, + std::vector padding, + std::vector dilation, + int groups, + StreamOrDevice s) { + std::vector padding_lo(padding.size()); + std::vector padding_hi(padding.size()); + for (int i = 0; i < padding.size(); ++i) { + int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1); + padding_lo[i] = wt_size - padding[i] - 1; + + int conv_output_shape = (input.shape(i + 1) - 1) * stride[i] - + 2 * padding[i] + dilation[i] * (weight.shape(i + 1) - 1) + 1; + + int in_size = 1 + (conv_output_shape - 1); + int out_size = 1 + stride[i] * (input.shape(1 + i) - 1); + padding_hi[i] = in_size - out_size + padding[i]; + } + + return conv_general( + /* const array& input = */ input, + /* const array& weight = */ weight, + /* std::vector stride = */ std::vector(stride.size(), 1), + /* std::vector padding_lo = */ std::move(padding_lo), + /* std::vector padding_hi = */ std::move(padding_hi), + /* std::vector kernel_dilation = */ std::move(dilation), + /* std::vector input_dilation = */ std::move(stride), + /* int groups = */ groups, + /* bool flip = */ true, + s); +} + +/** 1D transposed convolution with a filter */ +array conv_transpose1d( 
+ const array& in_, + const array& wt_, + int stride /* = 1 */, + int padding /* = 0 */, + int dilation /* = 1 */, + int groups /* = 1 */, + StreamOrDevice s /* = {} */) { + return conv_transpose_general( + in_, wt_, {stride}, {padding}, {dilation}, groups, s); +} + +/** 2D transposed convolution with a filter */ +array conv_transpose2d( + const array& in_, + const array& wt_, + const std::pair& stride /* = {1, 1} */, + const std::pair& padding /* = {0, 0} */, + const std::pair& dilation /* = {1, 1} */, + int groups /* = 1 */, + StreamOrDevice s /* = {} */) { + return conv_transpose_general( + in_, + wt_, + {stride.first, stride.second}, + {padding.first, padding.second}, + {dilation.first, dilation.second}, + groups, + s); +} + +/** 3D transposed convolution with a filter */ +array conv_transpose3d( + const array& in_, + const array& wt_, + const std::tuple& stride /* = {1, 1, 1} */, + const std::tuple& padding /* = {0, 0, 0} */, + const std::tuple& dilation /* = {1, 1, 1} */, + int groups /* = 1 */, + StreamOrDevice s /* = {} */) { + return conv_transpose_general( + in_, + wt_, + {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)}, + {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, + {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, + groups, + s); +} + /** General convolution with a filter */ array conv_general( array in, diff --git a/mlx/ops.h b/mlx/ops.h index 711a37aa9..42445ccb6 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1247,6 +1247,36 @@ array conv3d( int groups = 1, StreamOrDevice s = {}); +/** 1D transposed convolution with a filter */ +array conv_transpose1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D transposed convolution with a filter */ +array conv_transpose2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const 
std::pair& dilation = {1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** 3D transposed convolution with a filter */ +array conv_transpose3d( + const array& input, + const array& weight, + const std::tuple& stride = {1, 1, 1}, + const std::tuple& padding = {0, 0, 0}, + const std::tuple& dilation = {1, 1, 1}, + int groups = 1, + StreamOrDevice s = {}); + /** Quantized matmul multiplies x with a quantized matrix w*/ array quantized_matmul( const array& x, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fb4686e84..a1549fa6f 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -952,8 +952,8 @@ std::vector Convolution::vjp( /* const array& input = */ cotan, /* const array& weight = */ wt_trans, /* std::vector stride = */ input_dilation_, - /* std::vector padding_lo = */ padding_lo_, - /* std::vector padding_hi = */ padding_hi_, + /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_hi = */ padding_hi, /* std::vector kernel_dilation = */ kernel_dilation_, /* std::vector input_dilation = */ kernel_strides_, /* int groups = */ 1, @@ -990,36 +990,61 @@ std::vector Convolution::vjp( no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1); } - if (no_dilation) { + if (no_dilation && !flip_) { auto grad = conv_weight_backward_patches( in, wt, cotan, kernel_strides_, padding_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; - - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; - } - - auto in_trans = swapaxes(in, 0, -1, stream()); auto cotan_trans = swapaxes(cotan, 0, -1, stream()); - auto grad_trans = conv_general( - /* const array& input = */ in_trans, - /* const array& weight = */ cotan_trans, - /* 
std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding_lo, - /* std::vector padding_hi = */ padding_hi, - /* std::vector kernel_dilation = */ kernel_strides_, - /* std::vector input_dilation = */ input_dilation_, - /* int groups = */ 1, - /* bool flip = */ flip_, - stream()); - auto grad = swapaxes(grad_trans, 0, -1, stream()); - grads.push_back(grad); + auto in_trans = swapaxes(in, 0, -1, stream()); + + if (flip_) { + auto padding = padding_; + for (int i = 0; i < padding.size(); i++) { + int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); + padding[i] = wt_size - padding_[i] - 1; + } + + auto grad_trans = conv_general( + /* const array& input = */ cotan_trans, + /* const array& weight = */ in_trans, + /* std::vector stride = */ kernel_dilation_, + /* std::vector padding_lo = */ padding, + /* std::vector padding_hi = */ padding, + /* std::vector kernel_dilation = */ input_dilation_, + /* std::vector input_dilation = */ kernel_strides_, + /* int groups = */ 1, + /* bool flip = */ false, + stream()); + auto grad = swapaxes(grad_trans, 0, -1, stream()); + grads.push_back(grad_trans); + } else { + std::vector padding_lo = padding_; + std::vector padding_hi = padding_; + + for (int i = 0; i < padding_hi.size(); ++i) { + int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); + int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); + int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); + padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; + } + + auto in_trans = swapaxes(in, 0, -1, stream()); + auto cotan_trans = swapaxes(cotan, 0, -1, stream()); + auto grad_trans = conv_general( + /* const array& input = */ in_trans, + /* const array& weight = */ cotan_trans, + /* std::vector stride = */ kernel_dilation_, + /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_hi = */ padding_hi, + /* std::vector kernel_dilation = */ kernel_strides_, + /* std::vector input_dilation = */ 
input_dilation_, + /* int groups = */ 1, + /* bool flip = */ false, + stream()); + auto grad = swapaxes(grad_trans, 0, -1, stream()); + grads.push_back(grad); + } } } } diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 890c4ee5d..3cf5e33a8 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -55,6 +55,11 @@ from mlx.nn.layers.activations import ( from mlx.nn.layers.base import Module from mlx.nn.layers.containers import Sequential from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d +from mlx.nn.layers.convolution_transpose import ( + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py index 3202ebb73..27eb76e81 100644 --- a/python/mlx/nn/layers/convolution.py +++ b/python/mlx/nn/layers/convolution.py @@ -21,9 +21,9 @@ class Conv1d(Module): out_channels (int): The number of output channels kernel_size (int): The size of the convolution filters stride (int, optional): The stride when applying the filter. - Default: 1. + Default: ``1``. padding (int, optional): How many positions to 0-pad the input with. - Default: 0. + Default: ``0``. dilation (int, optional): The dilation of the convolution. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` @@ -84,9 +84,9 @@ class Conv2d(Module): out_channels (int): The number of output channels. kernel_size (int or tuple): The size of the convolution filters. stride (int or tuple, optional): The size of the stride when - applying the filter. Default: 1. + applying the filter. Default: ``1``. padding (int or tuple, optional): How many positions to 0-pad - the input with. Default: 0. + the input with. Default: ``0``. 
dilation (int or tuple, optional): The dilation of the convolution. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py new file mode 100644 index 000000000..ec55049e5 --- /dev/null +++ b/python/mlx/nn/layers/convolution_transpose.py @@ -0,0 +1,206 @@ +# Copyright © 2023 Apple Inc. + +import math +from typing import Union + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class ConvTranspose1d(Module): + """Applies a 1-dimensional transposed convolution over the multi-channel input sequence. + + The channels are expected to be last i.e. the input shape should be ``NLC`` where: + + * ``N`` is the batch dimension + * ``L`` is the sequence length + * ``C`` is the number of input channels + + Args: + in_channels (int): The number of input channels + out_channels (int): The number of output channels + kernel_size (int): The size of the convolution filters + stride (int, optional): The stride when applying the filter. + Default: ``1``. + padding (int, optional): How many positions to 0-pad the input with. + Default: ``0``. + dilation (int, optional): The dilation of the convolution. + bias (bool, optional): If ``True`` add a learnable bias to the output. 
+ Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + bias: bool = True, + ): + super().__init__() + + scale = math.sqrt(1 / (in_channels * kernel_size)) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.dilation = dilation + self.stride = stride + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " + f"padding={self.padding}, dilation={self.dilation}, " + f"bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv_transpose1d( + x, self.weight, self.stride, self.padding, self.dilation + ) + if "bias" in self: + y = y + self.bias + return y + + +class ConvTranspose2d(Module): + """Applies a 2-dimensional transposed convolution over the multi-channel input image. + + The channels are expected to be last i.e. the input shape should be ``NHWC`` where: + + * ``N`` is the batch dimension + * ``H`` is the input image height + * ``W`` is the input image width + * ``C`` is the number of input channels + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int or tuple): The size of the convolution filters. + stride (int or tuple, optional): The size of the stride when + applying the filter. Default: ``1``. + padding (int or tuple, optional): How many positions to 0-pad + the input with. Default: ``0``. + dilation (int or tuple, optional): The dilation of the convolution. + bias (bool, optional): If ``True`` add a learnable bias to the + output. 
Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple], + stride: Union[int, tuple] = 1, + padding: Union[int, tuple] = 0, + dilation: Union[int, tuple] = 1, + bias: bool = True, + ): + super().__init__() + + kernel_size, stride, padding = map( + lambda x: (x, x) if isinstance(x, int) else x, + (kernel_size, stride, padding), + ) + scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1])) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, *kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.stride = stride + self.dilation = dilation + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " + f"padding={self.padding}, dilation={self.dilation}, " + f"bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv_transpose2d( + x, self.weight, self.stride, self.padding, self.dilation + ) + if "bias" in self: + y = y + self.bias + return y + + +class ConvTranspose3d(Module): + """Applies a 3-dimensional transposed convolution over the multi-channel input image. + + The channels are expected to be last i.e. the input shape should be ``NDHWC`` where: + + * ``N`` is the batch dimension + * ``D`` is the input image depth + * ``H`` is the input image height + * ``W`` is the input image width + * ``C`` is the number of input channels + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int or tuple): The size of the convolution filters. + stride (int or tuple, optional): The size of the stride when + applying the filter. Default: ``1``. + padding (int or tuple, optional): How many positions to 0-pad + the input with. Default: ``0``. + bias (bool, optional): If ``True`` add a learnable bias to the + output. 
Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple], + stride: Union[int, tuple] = 1, + padding: Union[int, tuple] = 0, + bias: bool = True, + ): + super().__init__() + + kernel_size, stride, padding = map( + lambda x: (x, x, x) if isinstance(x, int) else x, + (kernel_size, stride, padding), + ) + scale = math.sqrt( + 1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) + ) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, *kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.stride = stride + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " + f"padding={self.padding}, bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv_transpose3d(x, self.weight, self.stride, self.padding) + if "bias" in self: + y = y + self.bias + return y diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7b3dace34..a594fd287 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3238,12 +3238,12 @@ void init_ops(nb::module_& m) { 1D convolution over an input with several channels Args: - input (array): input array of shape (``N``, ``H``, ``C_in``) - weight (array): weight array of shape (``C_out``, ``H``, ``C_in``) - stride (int, optional): kernel stride. Default: ``1``. - padding (int, optional): input padding. Default: ``0``. - dilation (int, optional): kernel dilation. Default: ``1``. - groups (int, optional): input feature groups. Default: ``1``. + input (array): Input array of shape ``(N, H, C_in)``. + weight (array): Weight array of shape ``(C_out, H, C_in)``. + stride (int, optional): Kernel stride. Default: ``1``. + padding (int, optional): Input padding. Default: ``0``. + dilation (int, optional): Kernel dilation. Default: ``1``. 
+ groups (int, optional): Input feature groups. Default: ``1``. Returns: array: The convolved array. @@ -3296,8 +3296,8 @@ void init_ops(nb::module_& m) { 2D convolution over an input with several channels Args: - input (array): input array of shape ``(N, H, W, C_in)`` - weight (array): weight array of shape ``(C_out, H, W, C_in)`` + input (array): Input array of shape ``(N, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, H, W, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3368,8 +3368,173 @@ void init_ops(nb::module_& m) { Note: Only the default ``groups=1`` is currently supported. Args: - input (array): input array of shape ``(N, D, H, W, C_in)`` - weight (array): weight array of shape ``(C_out, D, H, W, C_in)`` + input (array): Input array of shape ``(N, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + stride (int or tuple(int), optional): :obj:`tuple` of size 3 with + kernel strides. All spatial dimensions get the same stride if + only one number is specified. Default: ``1``. + padding (int or tuple(int), optional): :obj:`tuple` of size 3 with + symmetric input padding. All spatial dimensions get the same + padding if only one number is specified. Default: ``0``. + dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with + kernel dilation. All spatial dimensions get the same dilation + if only one number is specified. Default: ``1`` + groups (int, optional): input feature groups. Default: ``1``. + + Returns: + array: The convolved array. 
+ )pbdoc"); + m.def( + "conv_transpose1d", + &conv_transpose1d, + nb::arg(), + nb::arg(), + "stride"_a = 1, + "padding"_a = 0, + "dilation"_a = 1, + "groups"_a = 1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + 1D transposed convolution over an input with several channels + + Args: + input (array): Input array of shape ``(N, H, C_in)``. + weight (array): Weight array of shape ``(C_out, H, C_in)``. + stride (int, optional): Kernel stride. Default: ``1``. + padding (int, optional): Input padding. Default: ``0``. + dilation (int, optional): Kernel dilation. Default: ``1``. + groups (int, optional): Input feature groups. Default: ``1``. + + Returns: + array: The convolved array. + )pbdoc"); + m.def( + "conv_transpose2d", + [](const array& input, + const array& weight, + const std::variant>& stride, + const std::variant>& padding, + const std::variant>& dilation, + int groups, + StreamOrDevice s) { + std::pair stride_pair{1, 1}; + std::pair padding_pair{0, 0}; + std::pair dilation_pair{1, 1}; + + if (auto pv = std::get_if(&stride); pv) { + stride_pair = std::pair{*pv, *pv}; + } else { + stride_pair = std::get>(stride); + } + + if (auto pv = std::get_if(&padding); pv) { + padding_pair = std::pair{*pv, *pv}; + } else { + padding_pair = std::get>(padding); + } + + if (auto pv = std::get_if(&dilation); pv) { + dilation_pair = std::pair{*pv, *pv}; + } else { + dilation_pair = std::get>(dilation); + } + + return conv_transpose2d( + input, weight, stride_pair, padding_pair, dilation_pair, groups, s); + }, + nb::arg(), + nb::arg(), + "stride"_a = 1, + "padding"_a = 0, + "dilation"_a = 1, + "groups"_a = 1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, 
Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + 2D transposed convolution over an input with several channels + + Note: Only the default ``groups=1`` is currently supported. + + Args: + input (array): Input array of shape ``(N, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + stride (int or tuple(int), optional): :obj:`tuple` of size 2 with + kernel strides. All spatial dimensions get the same stride if + only one number is specified. Default: ``1``. + padding (int or tuple(int), optional): :obj:`tuple` of size 2 with + symmetric input padding. All spatial dimensions get the same + padding if only one number is specified. Default: ``0``. + dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with + kernel dilation. All spatial dimensions get the same dilation + if only one number is specified. Default: ``1`` + groups (int, optional): input feature groups. Default: ``1``. + + Returns: + array: The convolved array. 
+ )pbdoc"); + m.def( + "conv_transpose3d", + [](const array& input, + const array& weight, + const std::variant>& stride, + const std::variant>& padding, + const std::variant>& dilation, + int groups, + StreamOrDevice s) { + std::tuple stride_tuple{1, 1, 1}; + std::tuple padding_tuple{0, 0, 0}; + std::tuple dilation_tuple{1, 1, 1}; + + if (auto pv = std::get_if(&stride); pv) { + stride_tuple = std::tuple{*pv, *pv, *pv}; + } else { + stride_tuple = std::get>(stride); + } + + if (auto pv = std::get_if(&padding); pv) { + padding_tuple = std::tuple{*pv, *pv, *pv}; + } else { + padding_tuple = std::get>(padding); + } + + if (auto pv = std::get_if(&dilation); pv) { + dilation_tuple = std::tuple{*pv, *pv, *pv}; + } else { + dilation_tuple = std::get>(dilation); + } + + return conv_transpose3d( + input, + weight, + stride_tuple, + padding_tuple, + dilation_tuple, + groups, + s); + }, + nb::arg(), + nb::arg(), + "stride"_a = 1, + "padding"_a = 0, + "dilation"_a = 1, + "groups"_a = 1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + 3D transposed convolution over an input with several channels + + Note: Only the default ``groups=1`` is currently supported. + + Args: + input (array): Input array of shape ``(N, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. 
@@ -3465,8 +3630,8 @@ void init_ops(nb::module_& m) {
         General convolution over an input with several channels
 
         Args:
-            input (array): Input array of shape ``(N, ..., C_in)``
-            weight (array): Weight array of shape ``(C_out, ..., C_in)``
+            input (array): Input array of shape ``(N, ..., C_in)``.
+            weight (array): Weight array of shape ``(C_out, ..., C_in)``.
             stride (int or list(int), optional): :obj:`list` with kernel strides.
                 All spatial dimensions get the same stride if
                 only one number is specified. Default: ``1``.
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index d5b43b2a2..46291cf6d 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
                     flip=flip,
                 )
 
+    def test_conv_general_flip_grad(self):  # weight VJP of a flipped conv_general vs. a hand-computed value
+        for s in (1, 2):  # s is used as the input dilation below
+            w = mx.random.normal(shape=(1, 2, 2, 1))
+            x = mx.random.normal(shape=(1, 2, 2, 1))
+
+            def conv_t(w):  # closes over x and s; differentiated w.r.t. w only
+                return mx.conv_general(
+                    x,
+                    w,
+                    stride=1,
+                    padding=(1, 1),
+                    kernel_dilation=1,
+                    input_dilation=s,
+                    flip=True,
+                )
+
+            cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))  # cotangent handed to mx.vjp
+
+            dw = mx.vjp(conv_t, (w,), (cotan,))[1][0]  # [1] selects the VJPs, [0] the one for w
+
+            x = x.squeeze()  # drop the size-1 dims: (1, 2, 2, 1) -> (2, 2)
+            cotan = cotan.squeeze()
+            dw = dw.squeeze()
+
+            dw00 = (cotan[:-1:s, :-1:s] * x).sum()  # one expected entry per kernel tap:
+            dw01 = (cotan[:-1:s, 1::s] * x).sum()  # a stride-s window of cotan times x, summed
+            dw10 = (cotan[1::s, :-1:s] * x).sum()
+            dw11 = (cotan[1::s, 1::s] * x).sum()
+            expected = mx.array([[dw00, dw01], [dw10, dw11]])
+            self.assertTrue(mx.allclose(dw, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py
new file mode 100644
index 000000000..0efff048d
--- /dev/null
+++ b/python/tests/test_conv_transpose.py
@@ -0,0 +1,601 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import math
+import unittest
+from itertools import permutations  # NOTE(review): not referenced anywhere in this file
+
+import mlx.core as mx
+import mlx_tests
+import numpy as np
+
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:  # torch is optional; every test below is skipped without it
+    has_torch = False
+
+
+class TestConvTranspose(mlx_tests.MLXTestCase):  # checks mx.conv_transpose{1,2,3}d outputs and VJPs against torch
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_1D(self):  # forward pass, 1D
+        def run_conv_transpose_1D(  # one sub-test: same random data through mx and torch
+            N,
+            C,
+            O,
+            iH,
+            kH,
+            stride,
+            padding,
+            output_padding=0,  # accepted but never used in the body
+            dilation=1,
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                iH=iH,
+                kH=kH,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)  # fixed seed: every sub-test draws the same stream
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))  # (N, iH, C) -> (N, C, iH)
+                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))  # (O, kH, C/groups) -> (C/groups, O, kH)
+
+                out_mx = mx.conv_transpose1d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.conv_transpose1d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.transpose(out_pt, 2, 1)  # swap the last two axes back to compare with out_mx
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in (
+                (1, 1, 1),
+                (1, 6, 1),
+                (1, 1, 6),
+                (4, 32, 64),
+            ):
+                for iH, kH, stride, padding in (
+                    (1, 1, 1, 0),
+                    (3, 3, 1, 0),
+                    (31, 5, 5, 2),
+                ):
+                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
+
+        # Groups tests (only groups=1 is exercised)
+        N, C, O = (4, 32, 64)
+        for iH, kH, stride, padding in (
+            (1, 1, 1, 0),
+            (3, 3, 1, 0),
+            (31, 5, 5, 2),
+        ):
+            for group in (1,):
+                run_conv_transpose_1D(
+                    N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
+                )
+
+        # Strided inputs tests: operands go through mx.transpose before the op
+        for tpose_in, tpose_wt in (
+            ((0, 2, 1), (0, 1, 2)),
+            ((0, 2, 1), (0, 2, 1)),
+        ):
+            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
+                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
+                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_mx_t = mx.transpose(in_mx, tpose_in)
+                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
+                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)  # default stride/padding/dilation
+
+                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
+                wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))
+
+                out_pt = torch.conv_transpose1d(in_pt, wt_pt)
+                out_pt = torch.transpose(out_pt, 2, 1)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_1D_grad(self):  # input and weight gradients, 1D
+        def run_conv_transpose1D_grad(  # one sub-test: torch autograd vs. mx.vjp
+            N,
+            C,
+            O,
+            iH,
+            kH,
+            stride,
+            padding,
+            dilation=1,
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                iH=iH,
+                kH=kH,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                # The output length is not precomputed; the cotangent below takes torch's output shape.
+
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
+                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)
+
+                out_pt = F.conv_transpose1d(  # NOTE(review): groups is not forwarded to torch (callers only use groups=1)
+                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
+                )
+
+                # use torch to compute ct: backprop a scalar so out_pt.grad holds the cotangent
+                out_pt.retain_grad()
+                (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+
+                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()  # (N, C, iH) -> (N, iH, C)
+                pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()  # (C, O, kH) -> (O, kH, C)
+
+                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))  # same axis swap for the cotangent
+
+                def f(a, b):  # the function whose VJP is taken
+                    return mx.conv_transpose1d(
+                        a,
+                        b,
+                        stride=stride,
+                        padding=padding,
+                        dilation=dilation,
+                        groups=groups,
+                    )
+
+                _, outs_mx = mx.vjp(
+                    f,
+                    [
+                        in_mx,
+                        wt_mx,
+                    ],
+                    [
+                        ct_mx,
+                    ],
+                )
+
+                mx_grad_in, mx_grad_wt = outs_mx
+
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
+                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
+
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
+                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in (
+                (1, 1, 1),
+                (1, 6, 1),
+                (1, 1, 6),
+                (4, 32, 64),
+            ):
+                for iH, kH, stride, padding in (
+                    (1, 1, 1, 0),
+                    (3, 3, 1, 0),
+                    (31, 5, 5, 2),
+                ):
+                    run_conv_transpose1D_grad(
+                        N, C, O, iH, kH, stride, padding, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_2D(self):  # forward pass, 2D
+        def run_conv_transpose2D(  # one sub-test: same random data through mx and torch
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                kH, kW = kdim
+                scale = 1.0 / math.sqrt(kH * kW * C)  # applied to the input only; weights use std 1.0
+                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
+                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")  # (N, iH, iW, C) -> (N, C, iH, iW)
+                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")  # last axis moved to the front
+
+                out_mx = mx.conv_transpose2d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.conv_transpose2d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)  # axis 1 moved last to compare with out_mx
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in (
+                (1, 1, 1),
+                (1, 6, 1),
+                (1, 1, 6),
+                (4, 32, 64),
+            ):
+                for idim, kdim, stride, padding in (
+                    ((1, 1), (1, 1), (1, 1), (0, 0)),
+                    ((3, 3), (3, 1), (1, 1), (0, 0)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2)),
+                ):
+                    run_conv_transpose2D(
+                        N, C, O, idim, kdim, stride, padding, dtype=dtype
+                    )
+
+            # Groups tests (only groups=1 is exercised)
+            N, C, O = (4, 32, 64)
+            for idim, kdim, stride, padding in (
+                ((1, 1), (1, 1), (1, 1), (0, 0)),
+                ((3, 3), (3, 1), (1, 1), (0, 0)),
+                ((31, 31), (5, 5), (5, 5), (2, 2)),
+            ):
+                for group in (1,):
+                    run_conv_transpose2D(
+                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_2D_grad(self):  # input and weight gradients, 2D
+        def run_conv_transpose2D_grad(  # one sub-test: torch autograd vs. mx.vjp
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                kH, kW = kdim
+                scale = 1.0 / math.sqrt(kH * kW * C * O)  # used for both input and weights here
+
+                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
+                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
+                    True
+                )
+                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
+                    True
+                )
+
+                out_pt = F.conv_transpose2d(  # NOTE(review): groups is not forwarded to torch (callers only use groups=1)
+                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
+                )
+
+                # use torch to compute ct: backprop a scalar so out_pt.grad holds the cotangent
+                out_pt.retain_grad()
+                (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+
+                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()  # (N, C, iH, iW) -> (N, iH, iW, C)
+                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()  # (C, O, kH, kW) -> (O, kH, kW, C)
+
+                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))
+
+                def f(a, b):  # the function whose VJP is taken
+                    return mx.conv_transpose2d(
+                        a,
+                        b,
+                        stride=stride,
+                        padding=padding,
+                        dilation=dilation,
+                        groups=groups,
+                    )
+
+                _, outs_mx = mx.vjp(
+                    f,
+                    [in_mx, wt_mx],
+                    [ct_mx],
+                )
+
+                mx_grad_in, mx_grad_wt = outs_mx
+
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
+                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
+
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
+                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
+                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
+                ):
+                    run_conv_transpose2D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_3D(self):  # forward pass, 3D
+        def run_conv_transpose3D(  # one sub-test: same random data through mx and torch
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)  # applied to the input only; weights use std 1.0
+                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))  # (N, iD, iH, iW, C) -> (N, C, iD, iH, iW)
+                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))  # (O, kD, kH, kW, C) -> (C, O, kD, kH, kW)
+
+                out_mx = mx.conv_transpose3d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.conv_transpose3d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)  # axis 1 moved last to compare with out_mx
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in (
+                (1, 1, 1),
+                (1, 6, 1),
+                (1, 1, 6),
+                (2, 8, 16),
+            ):
+                for idim, kdim, stride, padding in (
+                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
+                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
+                    ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
+                ):
+                    run_conv_transpose3D(
+                        N, C, O, idim, kdim, stride, padding, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_3D_grad(self):  # input and weight gradients, 3D
+        def run_conv_transpose3D_grad(  # one sub-test: torch autograd vs. mx.vjp
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-4,  # looser than the 1e-5 default of the other helpers in this class
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)  # used for both input and weights here
+
+                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
+                    True
+                )
+                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
+                    True
+                )
+
+                out_pt = F.conv_transpose3d(  # unlike the 1D/2D grad helpers, groups is forwarded to torch
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+
+                # use torch to compute ct: backprop a scalar so out_pt.grad holds the cotangent
+                out_pt.retain_grad()
+                (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+
+                pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()  # axis 1 moved last, matching in_mx
+                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()  # (C, O, kD, kH, kW) -> (O, kD, kH, kW, C)
+
+                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))
+
+                def f(a, b):  # the function whose VJP is taken
+                    return mx.conv_transpose3d(
+                        a,
+                        b,
+                        stride=stride,
+                        padding=padding,
+                        dilation=dilation,
+                        groups=groups,
+                    )
+
+                _, outs_mx = mx.vjp(
+                    f,
+                    [in_mx, wt_mx],
+                    [ct_mx],
+                )
+
+                mx_grad_in, mx_grad_wt = outs_mx
+
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
+                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
+
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
+                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
+                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
+                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
+                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
+                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
+                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
+                ):
+                    run_conv_transpose3D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
+
+
+if __name__ == "__main__":
+    unittest.main()