mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-19 15:41:13 +08:00)
Transposed Convolution (#1245)
* initial implementation for conv_transpose: ran pre-commit, implemented conv_transpose, updated conv_general docstring, updated code comments, removed commented run_conv_checks, updated acknowledgments, added missing entry to ops.rst, added op to nn.layers, resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent ba3e913c7a
commit efeb9c0f02
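For orientation, here is a minimal usage sketch of the ops this commit adds. The shapes and values are illustrative only (channels-last layout, as documented in the diff below), and the shape comments assume the usual transposed-convolution output-size relation with no output padding.

```python
import mlx.core as mx

# Inputs are channels-last: (N, L, C_in) for 1D, (N, H, W, C_in) for 2D.
x1 = mx.random.normal(shape=(1, 8, 3))     # (N, L, C_in)
w1 = mx.random.normal(shape=(4, 5, 3))     # (C_out, kL, C_in)
y1 = mx.conv_transpose1d(x1, w1, stride=2, padding=1)
print(y1.shape)  # (1, 17, 4): (L - 1) * stride - 2 * padding + (kL - 1) + 1

x2 = mx.random.normal(shape=(1, 8, 8, 3))  # (N, H, W, C_in)
w2 = mx.random.normal(shape=(4, 3, 3, 3))  # (C_out, kH, kW, C_in)
y2 = mx.conv_transpose2d(x2, w2, stride=(2, 2), padding=(1, 1))
print(y2.shape)  # (1, 15, 15, 4)
```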
@ -18,6 +18,7 @@ MLX was developed with contributions from the following individuals:
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation.
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
135 benchmarks/python/conv_transpose_bench.py Normal file
@ -0,0 +1,135 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_transpose_2D
|
||||
|
||||
|
||||
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_transpose_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv_transpose2d(
|
||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.conv_transpose2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@ -16,6 +16,9 @@ Layers
Conv1d
Conv2d
Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout
Dropout2d
Dropout3d
@ -45,6 +45,9 @@ Operations
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general
cos
cosh
@ -1125,7 +1125,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
else {
|
||||
std::ostringstream msg;
|
||||
msg << "[Convolution::eval] Convolution currently only supports"
|
||||
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " spatial dimensions";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
@ -911,7 +911,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Throw error
|
||||
else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
|
||||
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
|
||||
}
|
||||
|
||||
// Clear copies
|
||||
|
87 mlx/ops.cpp
@ -3298,6 +3298,93 @@ array conv3d(
|
||||
s);
|
||||
}
|
||||
|
||||
// Helper function for transposed convolutions
|
||||
array conv_transpose_general(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
std::vector<int> stride,
|
||||
std::vector<int> padding,
|
||||
std::vector<int> dilation,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> padding_lo(padding.size());
|
||||
std::vector<int> padding_hi(padding.size());
|
||||
for (int i = 0; i < padding.size(); ++i) {
|
||||
int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1);
|
||||
padding_lo[i] = wt_size - padding[i] - 1;
|
||||
|
||||
int conv_output_shape = (input.shape(i + 1) - 1) * stride[i] -
|
||||
2 * padding[i] + dilation[i] * (weight.shape(i + 1) - 1) + 1;
|
||||
|
||||
int in_size = 1 + (conv_output_shape - 1);
|
||||
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
|
||||
padding_hi[i] = in_size - out_size + padding[i];
|
||||
}
|
||||
|
||||
return conv_general(
|
||||
/* const array& input = */ input,
|
||||
/* const array& weight = */ weight,
|
||||
/* std::vector<int> stride = */ std::vector(stride.size(), 1),
|
||||
/* std::vector<int> padding_lo = */ std::move(padding_lo),
|
||||
/* std::vector<int> padding_hi = */ std::move(padding_hi),
|
||||
/* std::vector<int> kernel_dilation = */ std::move(dilation),
|
||||
/* std::vector<int> input_dilation = */ std::move(stride),
|
||||
/* int groups = */ groups,
|
||||
/* bool flip = */ true,
|
||||
s);
|
||||
}
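The helper above realizes a transposed convolution as a regular convolution with unit stride, the kernel flipped, the input dilated by the original stride, and padding derived from the kernel size. As a sanity-check sketch (not part of the diff; with no output padding, padding_lo and padding_hi coincide, so a single symmetric padding value suffices), the same result can be reproduced from Python with ``mx.conv_general``:

```python
import mlx.core as mx

# Reproduce conv_transpose1d via conv_general, mirroring conv_transpose_general:
# stride 1, input dilated by the transposed-conv stride, flipped kernel, and
# padding = dilation * (k - 1) - padding (i.e. wt_size - padding - 1).
x = mx.random.normal(shape=(1, 6, 2))   # (N, L, C_in)
w = mx.random.normal(shape=(3, 4, 2))   # (C_out, kL, C_in)
stride, padding, dilation = 2, 1, 1

k = w.shape[1]
pad_eq = dilation * (k - 1) - padding   # wt_size - padding - 1

y_ref = mx.conv_transpose1d(x, w, stride=stride, padding=padding, dilation=dilation)
y_gen = mx.conv_general(
    x,
    w,
    stride=1,
    padding=pad_eq,
    kernel_dilation=dilation,
    input_dilation=stride,
    flip=True,
)
print(mx.allclose(y_ref, y_gen))  # expect True
```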
|
||||
|
||||
/** 1D transposed convolution with a filter */
|
||||
array conv_transpose1d(
|
||||
const array& in_,
|
||||
const array& wt_,
|
||||
int stride /* = 1 */,
|
||||
int padding /* = 0 */,
|
||||
int dilation /* = 1 */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_transpose_general(
|
||||
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
|
||||
}
|
||||
|
||||
/** 2D transposed convolution with a filter */
|
||||
array conv_transpose2d(
|
||||
const array& in_,
|
||||
const array& wt_,
|
||||
const std::pair<int, int>& stride /* = {1, 1} */,
|
||||
const std::pair<int, int>& padding /* = {0, 0} */,
|
||||
const std::pair<int, int>& dilation /* = {1, 1} */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_transpose_general(
|
||||
in_,
|
||||
wt_,
|
||||
{stride.first, stride.second},
|
||||
{padding.first, padding.second},
|
||||
{dilation.first, dilation.second},
|
||||
groups,
|
||||
s);
|
||||
}
|
||||
|
||||
/** 3D transposed convolution with a filter */
|
||||
array conv_transpose3d(
|
||||
const array& in_,
|
||||
const array& wt_,
|
||||
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
||||
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
||||
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_transpose_general(
|
||||
in_,
|
||||
wt_,
|
||||
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
||||
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
||||
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
||||
groups,
|
||||
s);
|
||||
}
|
||||
|
||||
/** General convolution with a filter */
|
||||
array conv_general(
|
||||
array in,
|
||||
|
30 mlx/ops.h
@ -1247,6 +1247,36 @@ array conv3d(
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 1D transposed convolution with a filter */
|
||||
array conv_transpose1d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
int stride = 1,
|
||||
int padding = 0,
|
||||
int dilation = 1,
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 2D transposed convolution with a filter */
|
||||
array conv_transpose2d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
const std::pair<int, int>& stride = {1, 1},
|
||||
const std::pair<int, int>& padding = {0, 0},
|
||||
const std::pair<int, int>& dilation = {1, 1},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 3D transposed convolution with a filter */
|
||||
array conv_transpose3d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
||||
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
||||
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantized matmul multiplies x with a quantized matrix w*/
|
||||
array quantized_matmul(
|
||||
const array& x,
|
||||
|
@ -952,8 +952,8 @@ std::vector<array> Convolution::vjp(
|
||||
/* const array& input = */ cotan,
|
||||
/* const array& weight = */ wt_trans,
|
||||
/* std::vector<int> stride = */ input_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo_,
|
||||
/* std::vector<int> padding_hi = */ padding_hi_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo,
|
||||
/* std::vector<int> padding_hi = */ padding_hi,
|
||||
/* std::vector<int> kernel_dilation = */ kernel_dilation_,
|
||||
/* std::vector<int> input_dilation = */ kernel_strides_,
|
||||
/* int groups = */ 1,
|
||||
@ -990,36 +990,61 @@ std::vector<array> Convolution::vjp(
|
||||
no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
|
||||
}
|
||||
|
||||
if (no_dilation) {
|
||||
if (no_dilation && !flip_) {
|
||||
auto grad = conv_weight_backward_patches(
|
||||
in, wt, cotan, kernel_strides_, padding_, stream());
|
||||
grads.push_back(grad);
|
||||
} else {
|
||||
std::vector<int> padding_lo = padding_;
|
||||
std::vector<int> padding_hi = padding_;
|
||||
|
||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
|
||||
}
|
||||
|
||||
auto in_trans = swapaxes(in, 0, -1, stream());
|
||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||
auto grad_trans = conv_general(
|
||||
/* const array& input = */ in_trans,
|
||||
/* const array& weight = */ cotan_trans,
|
||||
/* std::vector<int> stride = */ kernel_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo,
|
||||
/* std::vector<int> padding_hi = */ padding_hi,
|
||||
/* std::vector<int> kernel_dilation = */ kernel_strides_,
|
||||
/* std::vector<int> input_dilation = */ input_dilation_,
|
||||
/* int groups = */ 1,
|
||||
/* bool flip = */ flip_,
|
||||
stream());
|
||||
auto grad = swapaxes(grad_trans, 0, -1, stream());
|
||||
grads.push_back(grad);
|
||||
auto in_trans = swapaxes(in, 0, -1, stream());
|
||||
|
||||
if (flip_) {
|
||||
auto padding = padding_;
|
||||
for (int i = 0; i < padding.size(); i++) {
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding[i] = wt_size - padding_[i] - 1;
|
||||
}
|
||||
|
||||
auto grad_trans = conv_general(
|
||||
/* const array& input = */ cotan_trans,
|
||||
/* const array& weight = */ in_trans,
|
||||
/* std::vector<int> stride = */ kernel_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding,
|
||||
/* std::vector<int> padding_hi = */ padding,
|
||||
/* std::vector<int> kernel_dilation = */ input_dilation_,
|
||||
/* std::vector<int> input_dilation = */ kernel_strides_,
|
||||
/* int groups = */ 1,
|
||||
/* bool flip = */ false,
|
||||
stream());
|
||||
auto grad = swapaxes(grad_trans, 0, -1, stream());
|
||||
grads.push_back(grad_trans);
|
||||
} else {
|
||||
std::vector<int> padding_lo = padding_;
|
||||
std::vector<int> padding_hi = padding_;
|
||||
|
||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
|
||||
}
|
||||
|
||||
auto in_trans = swapaxes(in, 0, -1, stream());
|
||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||
auto grad_trans = conv_general(
|
||||
/* const array& input = */ in_trans,
|
||||
/* const array& weight = */ cotan_trans,
|
||||
/* std::vector<int> stride = */ kernel_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo,
|
||||
/* std::vector<int> padding_hi = */ padding_hi,
|
||||
/* std::vector<int> kernel_dilation = */ kernel_strides_,
|
||||
/* std::vector<int> input_dilation = */ input_dilation_,
|
||||
/* int groups = */ 1,
|
||||
/* bool flip = */ false,
|
||||
stream());
|
||||
auto grad = swapaxes(grad_trans, 0, -1, stream());
|
||||
grads.push_back(grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
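The new ``flip_`` branch above is what makes the weight gradient of a transposed convolution work, since ``conv_transpose*`` lowers to a flipped ``conv_general``. A quick check from Python (a sketch only; the dedicated tests added in ``python/tests/test_conv_transpose.py`` below are the authoritative coverage):

```python
import mlx.core as mx

x = mx.random.normal(shape=(1, 6, 2))   # (N, L, C_in)
w = mx.random.normal(shape=(3, 4, 2))   # (C_out, kL, C_in)

def loss(w):
    # Differentiating through conv_transpose1d exercises the flipped-VJP path.
    y = mx.conv_transpose1d(x, w, stride=2, padding=1)
    return (y * y).sum()

dw = mx.grad(loss)(w)
print(dw.shape)  # matches the weight shape: (3, 4, 2)
```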
|
||||
|
@ -55,6 +55,11 @@ from mlx.nn.layers.activations import (
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.containers import Sequential
|
||||
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
|
||||
from mlx.nn.layers.convolution_transpose import (
|
||||
ConvTranspose1d,
|
||||
ConvTranspose2d,
|
||||
ConvTranspose3d,
|
||||
)
|
||||
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||
from mlx.nn.layers.embedding import Embedding
|
||||
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||
|
@ -21,9 +21,9 @@ class Conv1d(Module):
|
||||
out_channels (int): The number of output channels
|
||||
kernel_size (int): The size of the convolution filters
|
||||
stride (int, optional): The stride when applying the filter.
|
||||
Default: 1.
|
||||
Default: ``1``.
|
||||
padding (int, optional): How many positions to 0-pad the input with.
|
||||
Default: 0.
|
||||
Default: ``0``.
|
||||
dilation (int, optional): The dilation of the convolution.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
@ -84,9 +84,9 @@ class Conv2d(Module):
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int or tuple): The size of the convolution filters.
|
||||
stride (int or tuple, optional): The size of the stride when
|
||||
applying the filter. Default: 1.
|
||||
applying the filter. Default: ``1``.
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: 0.
|
||||
the input with. Default: ``0``.
|
||||
dilation (int or tuple, optional): The dilation of the convolution.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
|
206 python/mlx/nn/layers/convolution_transpose.py Normal file
@ -0,0 +1,206 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
class ConvTranspose1d(Module):
|
||||
"""Applies a 1-dimensional transposed convolution over the multi-channel input sequence.
|
||||
|
||||
The channels are expected to be last i.e. the input shape should be ``NLC`` where:
|
||||
|
||||
* ``N`` is the batch dimension
|
||||
* ``L`` is the sequence length
|
||||
* ``C`` is the number of input channels
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels
|
||||
out_channels (int): The number of output channels
|
||||
kernel_size (int): The size of the convolution filters
|
||||
stride (int, optional): The stride when applying the filter.
|
||||
Default: ``1``.
|
||||
padding (int, optional): How many positions to 0-pad the input with.
|
||||
Default: ``0``.
|
||||
dilation (int, optional): The dilation of the convolution.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
dilation: int = 1,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
scale = math.sqrt(1 / (in_channels * kernel_size))
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, kernel_size, in_channels),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.stride = stride
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
|
||||
f"padding={self.padding}, dilation={self.dilation}, "
|
||||
f"bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
y = mx.conv_transpose1d(
|
||||
x, self.weight, self.stride, self.padding, self.dilation
|
||||
)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
||||
|
||||
class ConvTranspose2d(Module):
|
||||
"""Applies a 2-dimensional transposed convolution over the multi-channel input image.
|
||||
|
||||
The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
|
||||
|
||||
* ``N`` is the batch dimension
|
||||
* ``H`` is the input image height
|
||||
* ``W`` is the input image width
|
||||
* ``C`` is the number of input channels
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int or tuple): The size of the convolution filters.
|
||||
stride (int or tuple, optional): The size of the stride when
|
||||
applying the filter. Default: ``1``.
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: ``0``.
|
||||
dilation (int or tuple, optional): The dilation of the convolution.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, tuple],
|
||||
stride: Union[int, tuple] = 1,
|
||||
padding: Union[int, tuple] = 0,
|
||||
dilation: Union[int, tuple] = 1,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
kernel_size, stride, padding = map(
|
||||
lambda x: (x, x) if isinstance(x, int) else x,
|
||||
(kernel_size, stride, padding),
|
||||
)
|
||||
scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, *kernel_size, in_channels),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||
f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
|
||||
f"padding={self.padding}, dilation={self.dilation}, "
|
||||
f"bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
y = mx.conv_transpose2d(
|
||||
x, self.weight, self.stride, self.padding, self.dilation
|
||||
)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
||||
|
||||
class ConvTranspose3d(Module):
|
||||
"""Applies a 3-dimensional transposed convolution over the multi-channel input image.
|
||||
|
||||
The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:
|
||||
|
||||
* ``N`` is the batch dimension
|
||||
* ``D`` is the input image depth
|
||||
* ``H`` is the input image height
|
||||
* ``W`` is the input image width
|
||||
* ``C`` is the number of input channels
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int or tuple): The size of the convolution filters.
|
||||
stride (int or tuple, optional): The size of the stride when
|
||||
applying the filter. Default: ``1``.
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: ``0``.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, tuple],
|
||||
stride: Union[int, tuple] = 1,
|
||||
padding: Union[int, tuple] = 0,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
kernel_size, stride, padding = map(
|
||||
lambda x: (x, x, x) if isinstance(x, int) else x,
|
||||
(kernel_size, stride, padding),
|
||||
)
|
||||
scale = math.sqrt(
|
||||
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
|
||||
)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, *kernel_size, in_channels),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
|
||||
f"padding={self.padding}, bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
y = mx.conv_transpose3d(x, self.weight, self.stride, self.padding)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
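A brief usage sketch of the layers defined in this file (illustrative values only; output shapes follow the transposed-convolution size relation noted earlier):

```python
import mlx.core as mx
import mlx.nn as nn

# kernel_size=4, stride=2, padding=1 is the classic "upsample by 2" setting.
up1 = nn.ConvTranspose1d(in_channels=3, out_channels=8, kernel_size=4, stride=2, padding=1)
up2 = nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=4, stride=2, padding=1)

x1 = mx.random.normal(shape=(1, 16, 3))      # (N, L, C_in)
x2 = mx.random.normal(shape=(1, 16, 16, 3))  # (N, H, W, C_in)

print(up1(x1).shape)  # (1, 32, 8): doubles the length
print(up2(x2).shape)  # (1, 32, 32, 8): doubles both spatial dims
```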
|
@ -3238,12 +3238,12 @@ void init_ops(nb::module_& m) {
|
||||
1D convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): input array of shape (``N``, ``H``, ``C_in``)
|
||||
weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
|
||||
stride (int, optional): kernel stride. Default: ``1``.
|
||||
padding (int, optional): input padding. Default: ``0``.
|
||||
dilation (int, optional): kernel dilation. Default: ``1``.
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
input (array): Input array of shape ``(N, H, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, C_in)``.
|
||||
stride (int, optional): Kernel stride. Default: ``1``.
|
||||
padding (int, optional): Input padding. Default: ``0``.
|
||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
@ -3296,8 +3296,8 @@ void init_ops(nb::module_& m) {
|
||||
2D convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): input array of shape ``(N, H, W, C_in)``
|
||||
weight (array): weight array of shape ``(C_out, H, W, C_in)``
|
||||
input (array): Input array of shape ``(N, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
@ -3368,8 +3368,173 @@ void init_ops(nb::module_& m) {
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): input array of shape ``(N, D, H, W, C_in)``
|
||||
weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
|
||||
input (array): Input array of shape ``(N, D, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
symmetric input padding. All spatial dimensions get the same
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_transpose1d",
|
||||
&conv_transpose1d,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
1D transposed convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, H, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, C_in)``.
|
||||
stride (int, optional): Kernel stride. Default: ``1``.
|
||||
padding (int, optional): Input padding. Default: ``0``.
|
||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_transpose2d",
|
||||
[](const array& input,
|
||||
const array& weight,
|
||||
const std::variant<int, std::pair<int, int>>& stride,
|
||||
const std::variant<int, std::pair<int, int>>& padding,
|
||||
const std::variant<int, std::pair<int, int>>& dilation,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::pair<int, int> stride_pair{1, 1};
|
||||
std::pair<int, int> padding_pair{0, 0};
|
||||
std::pair<int, int> dilation_pair{1, 1};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
stride_pair = std::get<std::pair<int, int>>(stride);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&padding); pv) {
|
||||
padding_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
padding_pair = std::get<std::pair<int, int>>(padding);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&dilation); pv) {
|
||||
dilation_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
||||
}
|
||||
|
||||
return conv_transpose2d(
|
||||
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
2D transposed convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
symmetric input padding. All spatial dimensions get the same
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_transpose3d",
|
||||
[](const array& input,
|
||||
const array& weight,
|
||||
const std::variant<int, std::tuple<int, int, int>>& stride,
|
||||
const std::variant<int, std::tuple<int, int, int>>& padding,
|
||||
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
||||
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
||||
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&padding); pv) {
|
||||
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&dilation); pv) {
|
||||
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
||||
}
|
||||
|
||||
return conv_transpose3d(
|
||||
input,
|
||||
weight,
|
||||
stride_tuple,
|
||||
padding_tuple,
|
||||
dilation_tuple,
|
||||
groups,
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
3D transposed convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, D, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
@ -3465,8 +3630,8 @@ void init_ops(nb::module_& m) {
|
||||
General convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, ..., C_in)``
|
||||
weight (array): Weight array of shape ``(C_out, ..., C_in)``
|
||||
input (array): Input array of shape ``(N, ..., C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, ..., C_in)``.
|
||||
stride (int or list(int), optional): :obj:`list` with kernel strides.
|
||||
All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
|
@ -866,6 +866,37 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
flip=flip,
|
||||
)
|
||||
|
||||
def test_conv_general_flip_grad(self):
|
||||
for s in (1, 2):
|
||||
w = mx.random.normal(shape=(1, 2, 2, 1))
|
||||
x = mx.random.normal(shape=(1, 2, 2, 1))
|
||||
|
||||
def conv_t(w):
|
||||
return mx.conv_general(
|
||||
x,
|
||||
w,
|
||||
stride=1,
|
||||
padding=(1, 1),
|
||||
kernel_dilation=1,
|
||||
input_dilation=s,
|
||||
flip=True,
|
||||
)
|
||||
|
||||
cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))
|
||||
|
||||
dw = mx.vjp(conv_t, (w,), (cotan,))[1][0]
|
||||
|
||||
x = x.squeeze()
|
||||
cotan = cotan.squeeze()
|
||||
dw = dw.squeeze()
|
||||
|
||||
dw00 = (cotan[:-1:s, :-1:s] * x).sum()
|
||||
dw01 = (cotan[:-1:s, 1::s] * x).sum()
|
||||
dw10 = (cotan[1::s, :-1:s] * x).sum()
|
||||
dw11 = (cotan[1::s, 1::s] * x).sum()
|
||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||
self.assertTrue(mx.allclose(dw, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
601 python/tests/test_conv_transpose.py Normal file
@ -0,0 +1,601 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
has_torch = True
|
||||
except ImportError as e:
|
||||
has_torch = False
|
||||
|
||||
|
||||
class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_transpose_1D(self):
|
||||
def run_conv_transpose_1D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
iH,
|
||||
kH,
|
||||
stride,
|
||||
padding,
|
||||
output_padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
iH=iH,
|
||||
kH=kH,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))
|
||||
|
||||
out_mx = mx.conv_transpose1d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv_transpose1d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
# Groups tests
|
||||
N, C, O = (4, 32, 64)
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
for group in (1,):
|
||||
run_conv_transpose_1D(
|
||||
N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
|
||||
)
|
||||
|
||||
# Strided inputs tests
|
||||
for tpose_in, tpose_wt in (
|
||||
((0, 2, 1), (0, 1, 2)),
|
||||
((0, 2, 1), (0, 2, 1)),
|
||||
):
|
||||
with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
|
||||
in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||
wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_mx_t = mx.transpose(in_mx, tpose_in)
|
||||
wt_mx_t = mx.transpose(wt_mx, tpose_wt)
|
||||
out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)
|
||||
|
||||
in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))
|
||||
|
||||
out_pt = torch.conv_transpose1d(in_pt, wt_pt)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_transpose_1D_grad(self):
|
||||
def run_conv_transpose1D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
iH,
|
||||
kH,
|
||||
stride,
|
||||
padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
iH=iH,
|
||||
kH=kH,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
# oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
|
||||
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)
|
||||
|
||||
out_pt = F.conv_transpose1d(
|
||||
in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
|
||||
)
|
||||
|
||||
# use torch to compute ct
|
||||
out_pt.retain_grad()
|
||||
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
|
||||
|
||||
pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
|
||||
pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
|
||||
|
||||
ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv_transpose1d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[
|
||||
in_mx,
|
||||
wt_mx,
|
||||
],
|
||||
[
|
||||
ct_mx,
|
||||
],
|
||||
)
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
run_conv_transpose1D_grad(
|
||||
N, C, O, iH, kH, stride, padding, dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_transpose_2D(self):
|
||||
def run_conv_transpose2D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iH, iW = idim
|
||||
kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")
|
||||
|
||||
out_mx = mx.conv_transpose2d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv_transpose2d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 32, 64),
|
||||
):
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||
):
|
||||
run_conv_transpose2D(
|
||||
N, C, O, idim, kdim, stride, padding, dtype=dtype
|
||||
)
|
||||
|
||||
# Groups tests
|
||||
N, C, O = (4, 32, 64)
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||
):
|
||||
for group in (1,):
|
||||
run_conv_transpose2D(
|
||||
N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_transpose_2D_grad(self):
|
||||
def run_conv_transpose2D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iH, iW = idim
|
||||
kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kH * kW * C * O)
|
||||
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
|
||||
True
|
||||
)
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
|
||||
True
|
||||
)
|
||||
|
||||
out_pt = F.conv_transpose2d(
|
||||
in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
|
||||
)
|
||||
|
||||
# use torch to compute ct
|
||||
out_pt.retain_grad()
|
||||
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
|
||||
|
||||
pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
|
||||
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
|
||||
|
||||
ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv_transpose2d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[in_mx, wt_mx],
|
||||
[ct_mx],
|
||||
)
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
|
||||
for idim, kdim, stride, padding, dilation in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
|
||||
((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
|
||||
((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
|
||||
):
|
||||
run_conv_transpose2D_grad(
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_transpose_3D(self):
|
||||
def run_conv_transpose3D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iD, iH, iW = idim
|
||||
kD, kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
|
||||
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
|
||||
|
||||
out_mx = mx.conv_transpose3d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv_transpose3d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
|
||||
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(2, 8, 16),
|
||||
):
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
|
||||
((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
|
||||
):
|
||||
run_conv_transpose3D(
|
||||
N, C, O, idim, kdim, stride, padding, dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_transpose_3D_grad(self):
|
||||
def run_conv_transpose3D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-4,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iD, iH, iW = idim
|
||||
kD, kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
|
||||
|
||||
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
|
||||
True
|
||||
)
|
||||
wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
|
||||
True
|
||||
)
|
||||
|
||||
out_pt = F.conv_transpose3d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
# use torch to compute ct
|
||||
out_pt.retain_grad()
|
||||
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
|
||||
|
||||
pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
|
||||
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()
|
||||
|
||||
ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv_transpose3d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[in_mx, wt_mx],
|
||||
[ct_mx],
|
||||
)
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
|
||||
for idim, kdim, stride, padding, dilation in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
|
||||
):
|
||||
run_conv_transpose3D_grad(
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|