Convolution update (#651)

* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation (see the usage sketch below)

* Update winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Jagrit Digani
Date: 2024-02-28 20:11:16 -08:00
Committed by: GitHub
Parent: f5f18b704f
Commit: 776c3d226d
27 changed files with 2830 additions and 906 deletions
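Below is a minimal usage sketch of the functionality this commit targets. It assumes the new `conv_general` op (added to the operations docs in this diff) exposes `kernel_dilation`, `input_dilation`, and `flip` keyword arguments from Python; the exact signature is an assumption, not something this page confirms.

```python
# Hedged sketch: exercising kernel flipping and input dilation through the
# Python API. The keyword names below are assumed, not taken from this diff.
import mlx.core as mx

x = mx.random.uniform(shape=(1, 8, 8, 3))   # NHWC input
w = mx.random.uniform(shape=(4, 3, 3, 3))   # (O, kH, kW, C) weights

y = mx.conv_general(
    x,
    w,
    stride=(1, 1),
    padding=(1, 1),
    kernel_dilation=(1, 1),
    input_dilation=(2, 2),  # dilate the input (transposed-convolution style)
    flip=True,              # flip the kernel: convolution rather than cross-correlation
)
mx.eval(y)
print(y.shape)
```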


@@ -0,0 +1,129 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")
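Note on reading the table: `diff = time_torch / time_mlx - 1`, so a positive percentage means the PyTorch MPS baseline took longer than MLX for that shape, and the `ATTENTION` marker flags shapes where MLX is at least 2x slower than the baseline.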


@@ -35,6 +35,7 @@ Operations
convolve
conv1d
conv2d
conv_general
cos
cosh
dequantize


@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
@@ -27,14 +28,16 @@ void slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* start_wt_ptr = wt.data<T>();
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
@@ -61,12 +64,15 @@ void slow_conv_1D(
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
if (ih >= 0 && ih < iH) {
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih * in_stride_H + c * in_stride_C]) *
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
@@ -90,14 +96,16 @@ void slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
@@ -120,6 +128,8 @@ void slow_conv_2D(
const size_t out_stride_W = out.strides()[2];
const size_t out_stride_O = out.strides()[3];
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
@@ -131,8 +141,10 @@ void slow_conv_2D(
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
@@ -153,25 +165,74 @@ void slow_conv_2D(
} // o
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
@@ -191,13 +252,17 @@ void slow_conv_2D(
};
int oH_border_0 = 0;
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@@ -246,15 +311,18 @@ void dispatch_slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
return slow_conv_1D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_1D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_1D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@@ -267,15 +335,18 @@ void dispatch_slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
return slow_conv_2D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_2D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_2D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@@ -493,13 +564,16 @@ void conv_1D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
if (wt_dilation[0] == 1) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation);
}
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
void conv_2D_cpu(
@@ -508,8 +582,11 @@ void conv_2D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
@@ -523,12 +600,26 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
// 2D convolution
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 1D convolution
else if (in.ndim() == (1 + 2)) {
return conv_1D_cpu(
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// Throw error
else {
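For reference, a small NumPy sketch of the coordinate mapping the updated slow CPU path implements: a tap at `wh` (read in reverse when `flip` is set) contributes at input position `ih = oh*stride - pad + wh_flip*kdil`, and with input dilation that position is only backed by a stored sample when `ih % idil == 0`, in which case the sample lives at index `ih // idil`. This mirrors `slow_conv_1D` above; it is illustrative only, not the shipped implementation.

```python
import numpy as np

def slow_conv_1d_ref(x, w, stride=1, pad=0, kdil=1, idil=1, flip=False):
    # x: (N, iH, C), w: (O, wH, C); returns (N, oH, O). Illustrative sketch.
    N, iH, C = x.shape
    O, wH, _ = w.shape
    iH_dil = 1 + idil * (iH - 1)                       # dilated input extent
    oH = (iH_dil + 2 * pad - kdil * (wH - 1) - 1) // stride + 1
    y = np.zeros((N, oH, O), dtype=np.float32)
    for n in range(N):
        for oh in range(oH):
            for o in range(O):
                acc = 0.0
                for wh in range(wH):
                    wh_flip = wH - wh - 1 if flip else wh
                    ih = oh * stride - pad + wh_flip * kdil
                    # Valid only inside the dilated extent and on a grid point;
                    # out-of-range or between-sample positions contribute 0.
                    if 0 <= ih < iH_dil and ih % idil == 0:
                        acc += float(x[n, ih // idil] @ w[o, wh])
                y[n, oh, o] = acc
    return y
```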


@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -7,81 +7,72 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace mlx::steel;
namespace mlx::core {
namespace {
void explicit_gemm_conv_1D_gpu(
template <int N>
void explicit_gemm_conv_ND_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<1>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
const MLXConvParams<N>& conv_params) {
// Prepare unfolding array
std::vector<int> unfolded_shape = {
static_cast<int>(out.size() / conv_params.O),
static_cast<int>(wt.size() / conv_params.O)};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
// Fill with zeros
auto zero = array(0, in.dtype());
copy_gpu(zero, in_padded, CopyType::Scalar, s);
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Perform gemm
std::vector<array> copies = {zero, in_padded, in_strided};
std::vector<array> copies;
return steel_matmul(
s,
d,
/*a = */ in_strided,
/*a = */ in_unfolded,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*M = */ unfolded_shape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*K = */ unfolded_shape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_cols = */ unfolded_shape[1],
/*b_cols = */ unfolded_shape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
@@ -95,7 +86,9 @@ void conv_1D_gpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
@@ -106,24 +99,19 @@ void conv_1D_gpu(
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int dil[NDIM] = */ {wt_dilation[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
};
/* const int groups = */ 1,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (wt_dilation[0] == 1) {
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
}
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
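The explicit GEMM path above first unfolds the input into a matrix with one row per output pixel and `wt.size() / O` columns (the `unfolded_shape` computed in `explicit_gemm_conv_ND_gpu`), then multiplies it against the weight matrix with the weight treated as transposed, exactly as the `steel_matmul` call is set up. A hedged NumPy sketch of the same idea for 2D, unit dilation, no flip:

```python
import numpy as np

def conv2d_via_unfold(x, w, stride=(1, 1), pad=(0, 0)):
    # x: (N, H, W, C), w: (O, kH, kW, C); returns (N, oH, oW, O). Sketch only.
    N, H, W, C = x.shape
    O, kH, kW, _ = w.shape
    oH = (H + 2 * pad[0] - kH) // stride[0] + 1
    oW = (W + 2 * pad[1] - kW) // stride[1] + 1
    xp = np.pad(x, ((0, 0), (pad[0], pad[0]), (pad[1], pad[1]), (0, 0)))
    # Unfold: one row per output pixel, one column per (kh, kw, c) tap.
    unfolded = np.zeros((N * oH * oW, kH * kW * C), dtype=x.dtype)
    row = 0
    for n in range(N):
        for oh in range(oH):
            for ow in range(oW):
                patch = xp[n,
                           oh * stride[0]:oh * stride[0] + kH,
                           ow * stride[1]:ow * stride[1] + kW, :]
                unfolded[row] = patch.reshape(-1)
                row += 1
    # GEMM against the weight matrix, transposed on the right
    # (matching b_transposed = true in the steel_matmul call).
    out = unfolded @ w.reshape(O, kH * kW * C).T
    return out.reshape(N, oH, oW, O)
```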
void slow_conv_2D_gpu(
@@ -169,114 +157,262 @@ void implicit_gemm_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 32, bn = 32, bk = 16;
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
// Determine block and warp tiles
int wm = 2, wn = 2;
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
int bk = 16;
if (implicit_N <= 16) {
bn = 8;
wm = 4;
wn = 1;
}
int tn = (implicit_N + bn - 1) / bn;
int tm = (implicit_M + bm - 1) / bm;
int swizzle_log = 0;
// Fix small channel specialization
int n_channel_specialization = 0;
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
if (conv_params.C <= 2) {
gemm_k_iters = (implicit_K + bk - 1) / bk;
n_channel_specialization = conv_params.C;
} else if (conv_params.C <= 4) {
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
n_channel_specialization = conv_params.C;
}
bool small_filter = (!n_channel_specialization) &&
(conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
int inp_jump_w = sign * ijw;
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
sign * (conv_params.wS[1] - 1) * ijw;
// Build implicit gemm params
ImplicitGemmConv2DParams gemm_params{
/* const int M = */ implicit_M,
/* const int N = */ implicit_N,
/* const int K = */ implicit_K,
/* const int gemm_k_iterations = */ gemm_k_iters,
/* const int inp_jump_w = */ inp_jump_w,
/* const int inp_jump_h = */ inp_jump_h,
/* const int inp_jump_c = */ inp_jump_c,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
size_t grid_dim_y = (tm + tile - 1) / tile;
size_t grid_dim_x = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
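Summarizing the tile-selection heuristic above as a small Python paraphrase (for readability only; the authoritative logic is the C++ code):

```python
def pick_implicit_gemm_tiles(implicit_M, implicit_N, C, wS):
    # Paraphrase of the block/warp tile selection in implicit_gemm_conv_2D_gpu.
    wm, wn = 2, 2
    bm = 64 if (implicit_M >= 8192 and C >= 64) else 32
    bn = 64 if (bm == 64 or implicit_N >= 64) else 32
    bk = 16
    if implicit_N <= 16:
        bn, wm, wn = 8, 4, 1
    # Small-channel specialization (C <= 4) picks dedicated loaders; otherwise a
    # "small filter" variant is used for filters up to 16x16.
    n_channel_specialization = C if C <= 4 else 0
    small_filter = n_channel_specialization == 0 and wS[0] <= 16 and wS[1] <= 16
    return bm, bn, bk, wm, wn, n_channel_specialization, small_filter
```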
void explicit_gemm_conv_2D_gpu(
void implicit_gemm_conv_2D_general_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
// Fill with zeros
auto zero = array(0, in.dtype());
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Determine block and warp tiles
int wm = 2, wn = 2;
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Make jump params
int f_wgt_jump_h =
std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
int f_wgt_jump_w =
std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
int f_out_jump_h =
std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
int f_out_jump_w =
std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];
// Make strided view
std::vector<int> strided_shape = {
conv_params.N,
conv_params.oS[0],
conv_params.oS[1],
conv_params.wS[0],
conv_params.wS[1],
conv_params.C};
int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
int adj_out_hw = adj_out_h * adj_out_w;
int adj_implicit_m = conv_params.N * adj_out_hw;
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[2] * conv_params.str[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
Conv2DGeneralJumpParams jump_params{
/* const int f_wgt_jump_h = */ f_wgt_jump_h,
/* const int f_wgt_jump_w = */ f_wgt_jump_w,
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
/* const int f_out_jump_h = */ f_out_jump_h,
/* const int f_out_jump_w = */ f_out_jump_w,
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
/* const int adj_out_h = */ adj_out_h,
/* const int adj_out_w = */ adj_out_w,
/* const int adj_out_hw = */ adj_out_hw,
/* const int adj_implicit_m = */ adj_implicit_m};
// Perform gemm
std::vector<array> copies = {zero, in_padded, in_strided};
return steel_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
// Make base info
std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);
int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];
int init_h =
(conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
int init_w =
(conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;
int wh_base = 0;
while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
int wh_size =
((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
base_h[i] = {wh_base, wh_size};
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;
int ww_base = 0;
while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
int ww_size =
((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
base_w[j] = {ww_base, ww_size};
}
// Collect block sizes
int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
int bk = 16;
int tn = (implicit_N + bn - 1) / bn;
int tm = (adj_implicit_m + bm - 1) / bm;
int swizzle_log = 0;
// Get channel iteration info
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = channel_k_iters;
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
int inp_jump_w = sign * ijw;
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
sign * (conv_params.wS[1] - 1) * ijw;
// Build implicit gemm params
ImplicitGemmConv2DParams gemm_params{
/* const int M = */ implicit_M,
/* const int N = */ implicit_N,
/* const int K = */ implicit_K,
/* const int gemm_k_iterations = */ gemm_k_iters,
/* const int inp_jump_w = */ inp_jump_w,
/* const int inp_jump_h = */ inp_jump_h,
/* const int inp_jump_c = */ inp_jump_c,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
size_t grid_dim_y = (tm + tile - 1) / tile;
size_t grid_dim_x = tn * tile;
size_t grid_dim_z = f_out_jump_h * f_out_jump_w;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
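A worked note on the jump parameters built above: a kernel tap `wh` reads input row `ih = oh*str - pad + wh*kdil` (walking backwards when flipped), and with input dilation it only hits a stored sample when `ih % idil == 0`. Valid taps therefore repeat with period `f_wgt_jump_h = lcm(idil, kdil) / kdil`, and the pattern of first valid taps repeats over output rows with period `f_out_jump_h = lcm(idil, str) / str`, which is exactly what the `base_h` / `base_w` tables cache. An illustrative Python version of that host loop (assumes Python 3.9+ for `math.lcm`):

```python
from math import lcm

def base_taps(wH, stride, pad, kdil, idil, flip=False):
    # Mirrors the base_h/base_w construction above: for each output-row class,
    # find the first kernel tap that lands on the dilated input grid.
    jump = -kdil if flip else kdil
    init = (wH - 1) * kdil if flip else 0
    f_out_jump = lcm(idil, stride) // stride
    bases = []
    for i in range(f_out_jump):
        ih = i * stride - pad + init
        wh = 0
        while wh < wH and ih % idil != 0:
            wh += 1
            ih += jump
        bases.append(wh)
    return bases

# 3-tap kernel, stride 1, no padding, kernel dilation 1, input dilation 2:
# even output rows start at tap 0, odd output rows at tap 1.
print(base_taps(wH=3, stride=1, pad=0, kdil=1, idil=2))  # [0, 1]
```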
void winograd_conv_2D_gpu(
@@ -301,6 +437,7 @@ void winograd_conv_2D_gpu(
// Fill with zeros
array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
@@ -329,7 +466,8 @@ void winograd_conv_2D_gpu(
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int dil[NDIM] = */ {1, 1},
/* const int kdil[NDIM] = */ {1, 1},
/* const int idil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
@@ -339,6 +477,8 @@ void winograd_conv_2D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ false,
};
int O_c = conv_params.O;
@@ -462,6 +602,8 @@ void conv_2D_gpu(
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
@@ -473,37 +615,47 @@ void conv_2D_gpu(
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
// Direct to winograd conv
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
conv_params.dil[1] == 1) {
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
(channels_large || (channels_med && inp_large))) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
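In summary, the updated routing in `conv_2D_gpu` reads roughly as the following decision tree (a Python paraphrase of the conditions above, for readability only):

```python
def pick_conv2d_path(C, O, wS, stride, kdil, idil, flip, in_batch_stride):
    # Paraphrase of the routing logic in conv_2D_gpu (not the shipped code).
    stride_one = stride == (1, 1)
    kdil_one = kdil == (1, 1)
    idil_one = idil == (1, 1)
    inp_large = in_batch_stride >= (1 << 18)
    channels_large = (C + O) >= 512
    channels_med = (C + O) >= 256

    if (not flip and stride_one and kdil_one and idil_one
            and wS == (3, 3) and C % 32 == 0 and O % 32 == 0
            and (channels_large or (channels_med and inp_large))):
        return "winograd"
    if idil_one and (C <= 4 or C % 16 == 0) and (O <= 16 or O % 16 == 0):
        return "implicit_gemm"
    if C % 16 == 0 and O % 16 == 0:
        return "implicit_gemm_general"
    return "explicit_gemm (unfold + matmul)"
```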
@@ -534,11 +686,31 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
// 2D conv
if (out.ndim() == 4) {
conv_2D_gpu(
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
}
// 1D conv
else if (out.ndim() == 3) {
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
conv_1D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// Throw error
else {


@@ -51,11 +51,7 @@ endfunction(build_kernel_base)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
endif()
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
endfunction(build_kernel)
foreach(KERNEL ${KERNELS})


@@ -1,481 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DInputBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BM;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
int offsets_n[n_rows];
int offsets_oh[n_rows];
int offsets_ow[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bj),
params(params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params.oS[0] * params.oS[1];
for (int i = 0; i < n_rows; ++i) {
int offset_nhw = tid.y * BM + bi + i * bstride;
offsets_n[i] = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
offsets_oh[i] = hw / params.oS[1];
offsets_ow[i] = hw % params.oS[1];
}
(void)lid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
int n = offsets_n[i];
int oh = offsets_oh[i];
int ow = offsets_ow[i];
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
// Read from input if in bounds
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const device T* curr_src = src + n * params.in_strides[0] +
ih * params.in_strides[1] + iw * params.in_strides[2];
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
}
// Zero pad otherwise
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BN;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_.wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_h(0),
weight_w(0) {
(void)lid;
(void)tid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src =
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along N
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC Conv2DBlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t =
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
using loader_b_t =
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
using mma_t = Conv2DBlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const int K = params.wt_strides[0];
const int N = params.O;
B += c_col * K;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
}
};


@@ -1,16 +1,102 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive kernels
/// Naive unfold with dilation
///////////////////////////////////////////////////////////////////////////////
template <typename T, int N>
[[kernel]] void naive_unfold_Nd(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
// Set out
out += gid.z * filter_size + gid.y * (params->C);
// Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)
int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int wS = gid.y;
bool valid = n < params->N;
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
is[i] = is_ / params->idil[i];
oS /= params->oS[i];
wS /= params->wS[i];
}
if(valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
out[gid.x] = in[in_offset + gid.x];
} else {
out[gid.x] = T(0);
}
}
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
[[kernel]] void naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) \
instantiate_naive_unfold_nd(name, itype, 2) \
instantiate_naive_unfold_nd(name, itype, 3)
instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
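The validity rule in `naive_unfold_Nd` above works per axis: the (possibly flipped) tap position maps to `is_ = os_*str - pad + ws_*kdil`, the element is read only if it falls inside the dilated extent `1 + idil*(iS - 1)` and lands on a dilated grid point, and otherwise the unfolded entry is zero. A one-axis helper expressing the same rule (illustrative only):

```python
def unfold_index(os_, ws_, iS, wS, stride, pad, kdil, idil, flip):
    # Mirrors the per-axis validity check of naive_unfold_Nd above.
    if flip:
        ws_ = wS - ws_ - 1
    is_ = os_ * stride - pad + ws_ * kdil
    is_max = 1 + idil * (iS - 1)          # dilated input extent
    if 0 <= is_ < is_max and is_ % idil == 0:
        return is_ // idil                # index into the stored (undilated) input
    return None                           # zero-pad this unfolded entry
```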
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
@@ -58,8 +144,8 @@ template <typename T,
// Local in
for(int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
@@ -116,59 +202,6 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
in, wt, out,
params, tgp_memory,
tid, lid, simd_gid, simd_lid
);
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////


@@ -1,19 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int dil[NDIM]; // Kernel dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
};


@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal;
using namespace mlx::steel;


@@ -0,0 +1,189 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>
>
>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
B += c_col * K;
C += c_row * N + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
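For orientation, the kernel above computes convolution as an implicit GEMM: the im2col matrix is never materialized, only streamed tile by tile by the loaders. Below is a rough NumPy sketch of the equivalent explicit computation; the GEMM shapes (M = N*oH*oW, N = O, K = kH*kW*C) are inferred from the loaders above and are an assumption of this illustration, not code from this file.

# Sketch only: explicit im2col + GEMM that implicit_gemm_conv_2d performs implicitly.
import numpy as np

def conv2d_as_gemm(x, w, stride=(1, 1), padding=(0, 0)):
    # x: (N, H, W, C) in NHWC, w: (O, kH, kW, C), matching the layouts used here
    N, H, W, C = x.shape
    O, kH, kW, _ = w.shape
    sH, sW = stride
    pH, pW = padding
    oH = (H + 2 * pH - kH) // sH + 1
    oW = (W + 2 * pW - kW) // sW + 1
    xp = np.pad(x, ((0, 0), (pH, pH), (pW, pW), (0, 0)))
    # A is the im2col matrix of shape (N*oH*oW, kH*kW*C); the Metal kernel only ever
    # builds BK-wide slices of it in threadgroup memory instead of storing it whole.
    A = np.empty((N * oH * oW, kH * kW * C), dtype=x.dtype)
    row = 0
    for n in range(N):
        for oh in range(oH):
            for ow in range(oW):
                patch = xp[n, oh * sH:oh * sH + kH, ow * sW:ow * sW + kW, :]
                A[row] = patch.reshape(-1)
                row += 1
    B = w.reshape(O, -1)   # (O, kH*kW*C), i.e. the transposed B operand
    return (A @ B.T).reshape(N, oH, oW, O)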

View File

@@ -0,0 +1,209 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
using namespace mlx::steel;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = Conv2DInputBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t = Conv2DWeightBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int tid_z = tid.z;
const int base_oh = tid_z / jump_params->f_out_jump_w;
const int base_ow = tid_z % jump_params->f_out_jump_w;
const int base_wh = base_h[base_oh].weight_base;
const int base_ww = base_w[base_ow].weight_base;
const int base_wh_size = base_h[base_oh].weight_size;
const int base_ww_size = base_w[base_ow].weight_size;
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
B += c_col * K;
const int4 offsets_a(0, c_row, base_oh, base_ow);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
{
// Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
C += offset_n;
if (offset_n >= gemm_params->N)
return;
short diff = gemm_params->N - offset_n;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
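The general kernel above additionally splits the launch over tid.z so that each z-slice handles one residue class of output positions (base_oh, base_ow); in the strided and input-dilated cases each class touches only the weight taps recorded in base_h/base_w. A small Python sketch of that partitioning follows; the f_out_jump values and output extents are assumed purely for illustration.

# Sketch of the tid.z output partitioning in implicit_gemm_conv_2d_general.
f_out_jump_h, f_out_jump_w = 2, 3   # illustrative values only
oH, oW = 8, 9

for tid_z in range(f_out_jump_h * f_out_jump_w):
    base_oh = tid_z // f_out_jump_w
    base_ow = tid_z % f_out_jump_w
    # Mirrors: oh = (hw / adj_out_w) * f_out_jump_h + base_oh (and likewise for ow)
    rows = [oh for oh in range(oH) if oh % f_out_jump_h == base_oh]
    cols = [ow for ow in range(oW) if ow % f_out_jump_w == base_ow]
    print(tid_z, rows, cols)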

View File

@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"

View File

@@ -0,0 +1,449 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderLargeFilter {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_h;
short weight_w;
const device T* src[n_rows];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderLargeFilter(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
gemm_params(gemm_params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params->oS[0] * params->oS[1];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
int oh = hw / params->oS[1];
int ow = hw % params->oS[1];
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
// Adjust for flip
if (params->flip) {
ih += (params->wS[0] - 1) * params->kdil[0];
iw += (params->wS[1] - 1) * params->kdil[1];
}
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
iw * params->in_strides[2] + bj;
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int ih = read_ih[i] + weight_h * params->kdil[0];
int iw = read_iw[i] + weight_w * params->kdil[1];
// Read from input if in bounds
if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
(iw >= 0 && iw < params->iS[1])) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = src[i][j];
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params->wS[1]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_w;
}
return;
}
weight_w = 0;
if (++weight_h < params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_h;
}
return;
}
weight_h = 0;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_c;
}
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallFilter {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
using mask_t = short;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_h;
short weight_w;
const device T* src[n_rows];
mask_t mask_h[n_rows];
mask_t mask_w[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderSmallFilter(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
gemm_params(gemm_params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params->oS[0] * params->oS[1];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
int oh = hw / params->oS[1];
int ow = hw % params->oS[1];
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
// Adjust for flip
if (params->flip) {
ih += (params->wS[0] - 1) * params->kdil[0];
iw += (params->wS[1] - 1) * params->kdil[1];
}
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
iw * params->in_strides[2] + bj;
}
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
mask_h[i] = 0;
mask_w[i] = 0;
}
for (short kh = 0; kh < params->wS[0]; kh++) {
short flip_h = params->flip ? params->wS[0] - kh - 1 : kh;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int n = read_n[i];
int ih = read_ih[i] + flip_h * params->kdil[0];
bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];
mask_h[i] |= (in_bounds << kh);
}
}
for (short kw = 0; kw < params->wS[1]; kw++) {
short flip_w = params->flip ? params->wS[1] - kw - 1 : kw;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int iw = read_iw[i] + flip_w * params->kdil[1];
bool in_bounds = iw >= 0 && iw < params->iS[1];
mask_w[i] |= (in_bounds << kw);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
mask_t h_mask = mask_t(1) << weight_h;
mask_t w_mask = mask_t(1) << weight_w;
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Read from input if in bounds
if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = src[i][j];
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params->wS[1]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_w;
}
return;
}
weight_w = 0;
if (++weight_h < params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_h;
}
return;
}
weight_h = 0;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_c;
}
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
STEEL_CONST short BROWS = BN;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size =
(BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>* params;
int weight_hw;
const int read_n;
const bool do_read;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BN; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
}
} else {
for (short i = 0; i < BN; i += TROWS) {
if ((read_n + i) < params->O) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
src += params->wt_strides[2];
return;
}
weight_hw = 0;
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
}
};
} // namespace steel
} // namespace mlx
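Conv2DInputBlockLoaderSmallFilter above avoids per-tap bounds checks in the hot loop by precomputing one validity bit per kernel tap for each row (mask_h, mask_w). A minimal Python sketch of that masking idea; parameter values are illustrative only.

# Sketch: per-row validity bitmask over kernel taps, as built in the constructor above.
def build_row_mask(read_ih, wS0, kdil0, iS0, flip=False):
    mask = 0
    for kh in range(wS0):
        flip_h = (wS0 - kh - 1) if flip else kh
        ih = read_ih + flip_h * kdil0
        if 0 <= ih < iS0:
            mask |= 1 << kh
    return mask

mask_h = build_row_mask(read_ih=-2, wS0=5, kdil0=1, iS0=31)
for weight_h in range(5):
    # load_unsafe tests (mask_h & (1 << weight_h)) instead of recomputing bounds
    print(weight_h, bool(mask_h & (1 << weight_h)))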

View File

@@ -0,0 +1,319 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <short n_channels_>
struct ChannelHelper {
STEEL_CONST short n_channels = n_channels_;
STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
STEEL_CONST short excess = vec_size - n_channels_;
};
template <>
struct ChannelHelper<1> {
STEEL_CONST short n_channels = 1;
STEEL_CONST short vec_size = 1;
STEEL_CONST short excess = 0;
};
template <>
struct ChannelHelper<2> {
STEEL_CONST short n_channels = 2;
STEEL_CONST short vec_size = 2;
STEEL_CONST short excess = 0;
};
template <>
struct ChannelHelper<3> {
STEEL_CONST short n_channels = 3;
STEEL_CONST short vec_size = 4;
STEEL_CONST short excess = 1;
};
template <>
struct ChannelHelper<4> {
STEEL_CONST short n_channels = 4;
STEEL_CONST short vec_size = 4;
STEEL_CONST short excess = 0;
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short n_channels,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallChannels {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_hw;
const device T* src[n_rows];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
gemm_params(gemm_params_),
weight_hw(thread_idx % TCOLS) {
int out_n_pixels = params->oS[0] * params->oS[1];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
int oh = hw / params->oS[1];
int ow = hw % params->oS[1];
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
iw * params->in_strides[2];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
if (weight_hw >= params->wS[1] * params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
int wh = (weight_hw / params->wS[1]);
int ww = (weight_hw % params->wS[1]);
int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;
int weight_h = flip_h * params->kdil[0];
int weight_w = flip_w * params->kdil[1];
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int ih = read_ih[i] + weight_h;
int iw = read_iw[i] + weight_w;
// Read from input if in bounds
if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
(iw >= 0 && iw < params->iS[1])) {
const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
weight_w * params->in_strides[2];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < n_channels; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
STEEL_PRAGMA_UNROLL
for (short j = n_channels; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_hw += TCOLS;
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short n_channels,
short tgp_padding = 0>
struct Conv2DWeightBlockLoaderSmallChannels {
// Destination dimensions
STEEL_CONST short BROWS = BN;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>* params;
int weight_hw;
const int read_n;
const bool do_read;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld),
params(params_),
weight_hw(thread_idx % TCOLS),
read_n(offsets.y + bi),
do_read(read_n + BN <= gemm_params_->N) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
if (bi >= BROWS || bj >= BCOLS)
return;
if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < n_channels; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
STEEL_PRAGMA_UNROLL
for (short j = n_channels; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
} else {
for (short i = 0; i < BROWS; i += TROWS) {
if (((read_n + i) < params->O)) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < n_channels; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
STEEL_PRAGMA_UNROLL
for (short j = n_channels; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_hw += TCOLS;
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,288 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant Conv2DGeneralJumpParams* jump_params;
const short base_wh;
const short base_ww;
short weight_h;
short weight_w;
const device T* src[n_rows];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderGeneral(
const device T* src_,
threadgroup T* dst_,
const int4 offsets,
const constant MLXConvParams<2>* params_,
const constant Conv2DGeneralJumpParams* jump_params_,
const short base_wh_,
const short base_ww_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
jump_params(jump_params_),
base_wh(base_wh_),
base_ww(base_ww_),
weight_h(base_wh_),
weight_w(base_ww_) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / jump_params->adj_out_hw;
int hw = offset_nhw % jump_params->adj_out_hw;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + bj;
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
int ih_dil = read_ih[i] + h_flip * params->kdil[0];
int iw_dil = read_iw[i] + w_flip * params->kdil[1];
int ih = ih_dil / params->idil[0];
int iw = iw_dil / params->idil[1];
size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
// Read from input if in bounds
if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
(iw_dil >= 0 && iw < params->iS[1])) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = (src[i])[offset + j];
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;
if (weight_w < params->wS[1]) {
return;
}
weight_w = base_ww;
weight_h += jump_params->f_wgt_jump_h;
if (weight_h < params->wS[0]) {
return;
}
weight_h = base_wh;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += BK;
}
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
// Destination dimensions
STEEL_CONST short BROWS = BN;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size =
(BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>* params;
const constant Conv2DGeneralJumpParams* jump_params;
const short base_wh;
const short base_ww;
short weight_h;
short weight_w;
const int start_row;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoaderGeneral(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant Conv2DGeneralJumpParams* jump_params_,
const short base_wh_,
const short base_ww_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
jump_params(jump_params_),
base_wh(base_wh_),
base_ww(base_ww_),
weight_h(base_wh_),
weight_w(base_ww_),
start_row(offsets.y + bi) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src = src + weight_h * params->wt_strides[1] +
weight_w * params->wt_strides[2];
if ((start_row + BN <= params->O)) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BN; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
} else {
for (short i = 0; i < BN; i += TROWS) {
if ((start_row + i) < params->O) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;
if (weight_w < params->wS[1]) {
return;
}
weight_w = base_ww;
weight_h += jump_params->f_wgt_jump_h;
if (weight_h < params->wS[0]) {
return;
}
weight_h = base_wh;
src += BK;
}
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int kdil[NDIM]; // Kernel dilation
const int idil[NDIM]; // Input dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
const int groups; // Input channel groups
const bool flip;
};
namespace mlx {
namespace steel {
struct ImplicitGemmConv2DParams {
const int M;
const int N;
const int K;
const int gemm_k_iterations;
const int inp_jump_w;
const int inp_jump_h;
const int inp_jump_c;
const int tiles_n;
const int tiles_m;
const int swizzle_log;
};
struct Conv2DGeneralJumpParams {
const int f_wgt_jump_h;
const int f_wgt_jump_w;
const int f_out_jump_h;
const int f_out_jump_w;
const int adj_out_h;
const int adj_out_w;
const int adj_out_hw;
const int adj_implicit_m;
};
struct Conv2DGeneralBaseInfo {
int weight_base;
int weight_size;
};
} // namespace steel
} // namespace mlx

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -2,9 +2,15 @@
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
@@ -167,6 +173,9 @@ struct BlockMMA {
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
@@ -236,6 +245,9 @@ struct BlockMMA {
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {

View File

@@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/params.h"

View File

@@ -3,7 +3,6 @@
#pragma once
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/host.h"
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@@ -8,7 +8,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/host.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"

View File

@@ -2696,33 +2696,78 @@ array cummin(
namespace {
// Conv helpers
inline int conv_out_axis_size(
int in_dim,
int wt_dim,
int stride,
int padding,
int dilation) {
int ker = dilation * (wt_dim - 1);
return ((in_dim + 2 * padding - ker - 1) / stride) + 1;
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
return ((in_dim + padding - wt_dim) / stride) + 1;
}
// Conv helpers
inline int dilate_size(int dim, int dil) {
return 1 + dil * (dim - 1);
}
inline std::vector<int> conv_out_shape(
const std::vector<int>& in_shape,
const std::vector<int>& wt_shape,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilation) {
const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation) {
int N = in_shape[0];
int O = wt_shape[0];
std::vector<int> out_shape(in_shape.size());
int i = 0;
out_shape[i++] = N;
int spatial_dims = in_shape.size() - 2;
if (strides.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid strides " << strides << "for " << spatial_dims
<< "D convolution.";
throw std::invalid_argument(msg.str());
}
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (kernel_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (input_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid input dilation " << input_dilation << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
for (; i < in_shape.size() - 1; i++) {
if (pads[i - 1] < 0) {
if (kernel_dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Kernel dilation sizes must be positive."
<< " Got kernel dilation " << kernel_dilation << ".";
throw std::invalid_argument(msg.str());
}
if (input_dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Input dilation sizes must be positive."
<< " Got input dilation " << input_dilation << ".";
throw std::invalid_argument(msg.str());
}
if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
std::ostringstream msg;
msg << "[conv] Padding sizes must be non-negative."
<< " Got padding " << pads << ".";
<< " Got padding " << pads_lo << " | " << pads_hi << ".";
throw std::invalid_argument(msg.str());
}
@@ -2733,22 +2778,19 @@ inline std::vector<int> conv_out_shape(
throw std::invalid_argument(msg.str());
}
if (dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Dilation sizes must be positive."
<< " Got dilation " << dilation << ".";
throw std::invalid_argument(msg.str());
}
int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
int id = dilate_size(in_shape[i], input_dilation[i - 1]);
out_shape[i] = conv_out_axis_size(
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);
if (out_shape[i] <= 0) {
std::ostringstream msg;
msg << "[conv] Spatial dimensions of input after padding "
<< " cannot be smaller than weight spatial dimensions."
<< " Got input with shape " << in_shape << " and padding " << pads
<< " for weight of shape " << wt_shape << ".";
<< " Got error at axis " << i << " for input with shape " << in_shape
<< ", padding low " << pads_lo << ", padding high " << pads_hi
<< ", and weight of shape " << wt_shape << ".";
throw std::invalid_argument(msg.str());
}
}
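A quick worked check of the dilation-aware output-size arithmetic used above; this is a sketch mirroring dilate_size and the padded conv_out_axis_size, with illustrative numbers.

# Sketch of the output-size computation above, in Python.
def dilate_size(dim, dil):
    return 1 + dil * (dim - 1)

def conv_out_axis_size(in_dim, wt_dim, stride, padding):
    # padding is pads_lo + pads_hi for the axis
    return (in_dim + padding - wt_dim) // stride + 1

# 31-wide axis, 5-wide kernel, stride 5, pad 2+2, kernel dilation 3, input dilation 1
id_ = dilate_size(31, 1)   # 31
kd = dilate_size(5, 3)     # 13
print(conv_out_axis_size(id_, kd, 5, 2 + 2))  # -> 5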
@@ -2803,43 +2845,16 @@ array conv1d(
int dilation /* = 1 */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv1d] Cannot handle groups != 1 yet");
}
if (dilation != 1) {
throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet");
}
// Run checks
run_conv_checks(in_, wt_, 1);
auto in = in_;
auto wt = wt_;
// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());
in = astype(in, out_type, s);
wt = astype(wt, out_type, s);
std::vector<int> strides_vec = {stride};
std::vector<int> padding_vec = {padding};
std::vector<int> dilation_vec = {dilation};
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec);
return array(
out_shape,
in.dtype(),
std::make_unique<Convolution>(
to_stream(s),
padding_vec,
strides_vec,
dilation_vec,
std::vector<int>(1, 1)),
{in, wt});
return conv_general(
/* const array& input = */ in_,
/* const array& weight = */ wt_,
/* std::vector<int> stride = */ {stride},
/* std::vector<int> padding = */ {padding},
/* std::vector<int> kernel_dilation = */ {dilation},
/* std::vector<int> input_dilation = */ {1},
/* int groups = */ groups,
/* bool flip = */ false,
s);
}
/** 2D convolution with a filter */
@@ -2851,42 +2866,98 @@ array conv2d(
const std::pair<int, int>& dilation /* = {1, 1} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_general(
/* const array& input = */ in_,
/* const array& weight = */ wt_,
/* std::vector<int> stride = */ {stride.first, stride.second},
/* std::vector<int> padding = */ {padding.first, padding.second},
/* std::vector<int> kernel_dilation = */
{dilation.first, dilation.second},
/* std::vector<int> input_dilation = */ {1, 1},
/* int groups = */ groups,
/* bool flip = */ false,
s);
}
/** General convolution with a filter */
array conv_general(
array in,
array wt,
std::vector<int> stride /* = {} */,
std::vector<int> padding_lo /* = {} */,
std::vector<int> padding_hi /* = {} */,
std::vector<int> kernel_dilation /* = {} */,
std::vector<int> input_dilation /* = {} */,
int groups /* = 1 */,
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet");
throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
}
if (dilation.first != 1 || dilation.second != 1) {
throw std::invalid_argument("[conv2d] Cannot handle dilation != 1 yet");
int spatial_dims = in.ndim() - 2;
if (spatial_dims < 1 || spatial_dims > 2) {
throw std::invalid_argument(
"[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
" The inputs must be in the format [N, ..., C_in]");
}
// Run checks
run_conv_checks(in_, wt_, 2);
auto in = in_;
auto wt = wt_;
run_conv_checks(in, wt, spatial_dims);
// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());
in = astype(in, out_type, s);
wt = astype(wt, out_type, s);
std::vector<int> strides_vec = {stride.first, stride.second};
std::vector<int> padding_vec = {padding.first, padding.second};
std::vector<int> dilation_vec = {dilation.first, dilation.second};
if (stride.size() <= 1) {
int stride_int = stride.size() ? stride[0] : 1;
stride = std::vector<int>(spatial_dims, stride_int);
}
if (padding_lo.size() <= 1) {
int padding_int = padding_lo.size() ? padding_lo[0] : 0;
padding_lo = std::vector<int>(spatial_dims, padding_int);
}
if (padding_hi.size() <= 1) {
int padding_int = padding_hi.size() ? padding_hi[0] : 0;
padding_hi = std::vector<int>(spatial_dims, padding_int);
}
if (kernel_dilation.size() <= 1) {
int kernel_dilation_int = kernel_dilation.size() ? kernel_dilation[0] : 1;
kernel_dilation = std::vector<int>(spatial_dims, kernel_dilation_int);
}
if (input_dilation.size() <= 1) {
int input_dilation_int = input_dilation.size() ? input_dilation[0] : 1;
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
}
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec);
in.shape(),
wt.shape(),
stride,
padding_lo,
padding_hi,
kernel_dilation,
input_dilation);
return array(
out_shape,
in.dtype(),
std::make_unique<Convolution>(
to_stream(s),
padding_vec,
strides_vec,
dilation_vec,
std::vector<int>(2, 1)),
stride,
padding_lo,
kernel_dilation,
input_dilation,
groups,
flip),
{in, wt});
}

View File

@@ -1026,6 +1026,43 @@ array cummin(
/** Convolution operations */
/** General convolution with a filter */
array conv_general(
array input,
array weight,
std::vector<int> stride = {},
std::vector<int> padding_lo = {},
std::vector<int> padding_hi = {},
std::vector<int> kernel_dilation = {},
std::vector<int> input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {});
/** General convolution with a filter */
inline array conv_general(
const array& input,
const array& weight,
std::vector<int> stride = {},
std::vector<int> padding = {},
std::vector<int> kernel_dilation = {},
std::vector<int> input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {}) {
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ stride,
/* std::vector<int> padding_lo = */ padding,
/* std::vector<int> padding_hi = */ padding,
/* std::vector<int> kernel_dilation = */ kernel_dilation,
/* std::vector<int> input_dilation = */ input_dilation,
/* int groups = */ groups,
/* bool flip = */ flip,
/* StreamOrDevice s = */ s);
}
/** 1D convolution with a filter */
array conv1d(
const array& input,

View File

@@ -679,21 +679,13 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
return axis_ == c_other.axis_;
}
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> grads;
// Collect info
auto& in = primals[0];
auto& wt = primals[1];
auto cotan = cotangents[0];
int O = wt.shape(0);
array conv_weight_backward_patches(
const array& in,
const array& wt,
const array& cotan,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
StreamOrDevice s) {
// Resolve Padded input shapes and strides
std::vector<int> padding_starts(in.ndim(), 0);
std::vector<int> padding_ends = in.shape();
@@ -701,9 +693,9 @@ std::vector<array> Convolution::vjp(
// padded shape
for (int i = 1; i < in.ndim() - 1; i++) {
in_padded_shape[i] += 2 * padding_[i - 1];
padding_ends[i] += padding_[i - 1];
padding_starts[i] += padding_[i - 1];
in_padded_shape[i] += 2 * padding[i - 1];
padding_ends[i] += padding[i - 1];
padding_starts[i] += padding[i - 1];
}
// padded strides (contiguous)
@@ -712,6 +704,12 @@ std::vector<array> Convolution::vjp(
in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];
}
// Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
auto in_padded =
pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);
// Resolve strided patches
// patches are shaped as
@@ -726,62 +724,108 @@ std::vector<array> Convolution::vjp(
std::vector<size_t> patches_strides(patches_shape.size(), 1);
patches_strides[0] = in_padded_strides[0];
for (int i = 1; i < n_spatial_dim + 1; i++) {
patches_strides[i] = in_padded_strides[i] * kernel_strides_[i - 1];
patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
}
for (int i = 1; i < in.ndim(); i++) {
patches_strides[n_spatial_dim + i] = in_padded_strides[i];
}
// Reshape cotangents and weights for gemm
cotan = reshape(cotangents[0], {-1, O}, stream());
auto weight_reshaped = reshape(wt, {O, -1}, stream());
// Make patches from in
auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, s);
// Prepare for matmul
int O = wt.shape(0);
auto cotan_mat = reshape(cotan, {-1, O}, s);
in_patches = reshape(in_patches, {cotan_mat.shape(0), -1}, s);
auto grad = matmul(transpose(cotan_mat, {1, 0}, s), in_patches, s);
grad = reshape(grad, wt.shape(), s);
return grad;
}
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> grads;
// Collect info
auto& in = primals[0];
auto& wt = primals[1];
auto& cotan = cotangents[0];
for (int a : argnums) {
// Grads for input
if (a == 0) {
// Gemm with cotangents to get patches
auto grad_patches = matmul(cotan, weight_reshaped, stream());
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
// Prepare base grad array to accumulate on
int in_padded_size = in_padded_strides[0] * in_padded_shape[0];
auto grad = zeros(
{
in_padded_size,
},
in.dtype(),
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_[i] - 1;
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding_[i];
}
auto wt_trans = swapaxes(wt, 0, -1, stream());
auto grad = conv_general(
/* const array& input = */ cotan,
/* const array& weight = */ wt_trans,
/* std::vector<int> stride = */ input_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_dilation_,
/* std::vector<int> input_dilation = */ kernel_strides_,
/* int groups = */ 1,
/* bool flip = */ !flip_,
stream());
// Create index map
int patches_size = grad_patches.size();
auto idx = arange(in_padded_size, stream());
idx = as_strided(idx, patches_shape, patches_strides, 0, stream());
idx = reshape(idx, {patches_size}, stream());
// Flatten patches and scatter
auto flat_patches = reshape(grad_patches, {patches_size, 1}, stream());
grad = scatter_add(grad, idx, flat_patches, 0, stream());
// Reshape and slice away padding
grad = reshape(grad, in_padded_shape, stream());
grad = slice(grad, padding_starts, padding_ends, stream());
grads.push_back(grad);
}
// Grads for weight
else if (a == 1) {
// Make patches from in
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
auto in_padded = pad(
in, padded_axes, padding_, padding_, array(0, in.dtype()), stream());
auto in_patches =
as_strided(in_padded, patches_shape, patches_strides, 0, stream());
in_patches = reshape(in_patches, {cotan.shape(0), -1}, stream());
bool no_dilation = true;
auto grad =
matmul(transpose(cotan, {1, 0}, stream()), in_patches, stream());
grad = reshape(grad, wt.shape(), stream());
grads.push_back(grad);
for (int i = 0; i < input_dilation_.size(); i++) {
no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
}
if (no_dilation) {
auto grad = conv_weight_backward_patches(
in, wt, cotan, kernel_strides_, padding_, stream());
grads.push_back(grad);
} else {
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
}
auto in_trans = swapaxes(in, 0, -1, stream());
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto grad_trans = conv_general(
/* const array& input = */ in_trans,
/* const array& weight = */ cotan_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_strides_,
/* std::vector<int> input_dilation = */ input_dilation_,
/* int groups = */ 1,
/* bool flip = */ flip_,
stream());
auto grad = swapaxes(grad_trans, 0, -1, stream());
grads.push_back(grad);
}
}
}
@@ -793,7 +837,8 @@ bool Convolution::is_equivalent(const Primitive& other) const {
return padding_ == c_other.padding_ &&
kernel_strides_ == c_other.kernel_strides_ &&
kernel_dilation_ == c_other.kernel_dilation_ &&
input_dilation_ == c_other.input_dilation_;
input_dilation_ == c_other.input_dilation_ &&
groups_ == c_other.groups_ && flip_ == c_other.flip_;
}
std::vector<array> Copy::vjp(

View File

@@ -544,15 +544,19 @@ class Convolution : public UnaryPrimitive {
public:
explicit Convolution(
Stream stream,
const std::vector<int>& padding,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation)
const std::vector<int>& input_dilation,
const int groups = 1,
const bool flip = false)
: UnaryPrimitive(stream),
padding_(padding),
kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation){};
input_dilation_(input_dilation),
groups_(groups),
flip_(flip){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -571,6 +575,8 @@ class Convolution : public UnaryPrimitive {
std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_;
int groups_;
bool flip_;
void eval(const std::vector<array>& inputs, array& out);
};

View File

@@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
2D convolution over an input with several channels
@@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
array: The convolved array.
)pbdoc");
m.def(
"conv_general",
[](const array& input,
const array& weight,
const std::variant<int, std::vector<int>>& stride,
const std::variant<
int,
std::vector<int>,
std::pair<std::vector<int>, std::vector<int>>>& padding,
const std::variant<int, std::vector<int>>& kernel_dilation,
const std::variant<int, std::vector<int>>& input_dilation,
int groups,
bool flip,
StreamOrDevice s) {
std::vector<int> stride_vec;
std::vector<int> padding_lo_vec;
std::vector<int> padding_hi_vec;
std::vector<int> kernel_dilation_vec;
std::vector<int> input_dilation_vec;
if (auto pv = std::get_if<int>(&stride); pv) {
stride_vec.push_back(*pv);
} else {
stride_vec = std::get<std::vector<int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_lo_vec.push_back(*pv);
padding_hi_vec.push_back(*pv);
} else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
padding_lo_vec = *pv;
padding_hi_vec = *pv;
} else {
auto [pl, ph] =
std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
padding_lo_vec = pl;
padding_hi_vec = ph;
}
if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
kernel_dilation_vec.push_back(*pv);
} else {
kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
}
if (auto pv = std::get_if<int>(&input_dilation); pv) {
input_dilation_vec.push_back(*pv);
} else {
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
}
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ stride_vec,
/* std::vector<int> padding_lo = */ padding_lo_vec,
/* std::vector<int> padding_hi = */ padding_hi_vec,
/* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
/* std::vector<int> input_dilation = */ input_dilation_vec,
/* int groups = */ groups,
/* bool flip = */ flip,
s);
},
"input"_a,
"weight"_a,
py::pos_only(),
"stride"_a = 1,
"padding"_a = 0,
"kernel_dilation"_a = 1,
"input_dilation"_a = 1,
"groups"_a = 1,
"flip"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
General convolution over an input with several channels
.. note::
* Only 1d and 2d convolutions are supported at the moment.
* Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
stride (int or list(int), optional): :obj:`list` with kernel strides.
All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int, list(int), or tuple(list(int), list(int)), optional):
:obj:`list` with input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
kernel_dilation (int or list(int), optional): :obj:`list` with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``.
input_dilation (int or list(int), optional): :obj:`list` with
input dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``.
groups (int, optional): Input feature groups. Default: ``1``.
flip (bool, optional): Flip the order in which the spatial dimensions of
the weights are processed. Performs the cross-correlation operator when
``flip`` is ``False`` and the convolution operator otherwise.
Default: ``False``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"save",
&mlx_save_helper,
"file"_a,

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
[in_mx, wt_mx],
[ct_mx],
)
pt_grad_in = F.grad.conv1d_input(
in_pt.shape,
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
for idim, kdim, stride, padding, dilation in (
((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
):
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
run_conv2D_grad(
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
)
def __conv_general_test(
self,
in_shape,
wt_shape,
stride=1,
padding=0,
kernel_dilation=1,
input_dilation=1,
groups=1,
flip=False,
np_dtype=np.float32,
atol=1e-5,
):
with self.subTest(
in_shape=in_shape,
wt_shape=wt_shape,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
np_dtype=np_dtype,
):
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv_general(
in_mx,
wt_mx,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
def conv_general_pt(
inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
):
C = inp.size()[1]
ndim = inp.ndim - 2
map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
stride, padding, kernel_dilation, input_dilation = map(
map_ints, (stride, padding, kernel_dilation, input_dilation)
)
torch_convt_list = (
F.conv_transpose1d,
F.conv_transpose2d,
F.conv_transpose3d,
)
torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
conv_f = torch_conv_list[ndim - 1]
convt_f = torch_convt_list[ndim - 1]
if flip:
wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
if not np.all(np.array(input_dilation) == 1):
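# Emulate input dilation with a channel-wise transposed convolution:
# an all-ones kernel of shape [C, 1, ..., 1] applied with groups=C and
# stride = input_dilation inserts zeros between input elements without
# mixing channels.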
ones = torch.ones(
[C]
+ [
1,
]
* (ndim + 1)
).to(inp.dtype)
inp = convt_f(inp, ones, stride=input_dilation, groups=C)
return conv_f(
inp,
wt,
stride=stride,
padding=padding,
dilation=kernel_dilation,
groups=groups,
)
out_pt = conv_general_pt(
in_pt,
wt_pt,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
self.assertEqual(out_mx.shape, out_pt.shape)
self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_general(self):
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (1, 1)
padding = (2, 2)
kernel_dilation = (2, 3)
input_dilation = (1, 1)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 2)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 5)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 1)
input_dilation = (2, 5)
flip = True
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
if __name__ == "__main__":