Convolution update (#651)
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation
* Update winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
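As context for the shape math this change relies on, a hedged sketch (the helper below is illustrative and not part of the diff): input dilation inserts idil - 1 implicit zeros between input samples (used for transposed convolutions and input gradients), kernel dilation spreads the filter taps, and flip selects true convolution over cross-correlation.

// Output spatial size of a conv with kernel dilation (kdil) and input
// dilation (idil). Hypothetical helper, not code from this PR.
int conv_out_size(int iS, int wS, int str, int pad, int kdil, int idil) {
  int eff_in = (iS - 1) * idil + 1; // dilated input extent
  int eff_wt = (wS - 1) * kdil + 1; // dilated kernel extent
  return (eff_in + 2 * pad - eff_wt) / str + 1;
}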
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <algorithm>
 #include <cassert>
@@ -7,81 +7,72 @@
 
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/conv_params.h"
 #include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/matmul.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
+using namespace mlx::steel;
+
 namespace mlx::core {
 
 namespace {
 
-void explicit_gemm_conv_1D_gpu(
+template <int N>
+void explicit_gemm_conv_ND_gpu(
     const Stream& s,
     metal::Device& d,
     const array& in,
     const array& wt,
     array out,
-    const MLXConvParams<1>& conv_params) {
-  // Pad input
-  std::vector<int> padded_shape = {
-      conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+    const MLXConvParams<N>& conv_params) {
+  // Prepare unfolding array
+  std::vector<int> unfolded_shape = {
+      static_cast<int>(out.size() / conv_params.O),
+      static_cast<int>(wt.size() / conv_params.O)};
+  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
-  // Fill with zeros
-  auto zero = array(0, in.dtype());
-  copy_gpu(zero, in_padded, CopyType::Scalar, s);
+  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
 
-  // Pick input slice from padded
-  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
-  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
-  in_padded_slice.copy_shared_buffer(
-      in_padded,
-      in_padded.strides(),
-      in_padded.flags(),
-      in_padded_slice.size(),
-      data_offset);
+  // Prepare unfolding kernel
+  std::ostringstream kname;
+  kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
+  auto compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
 
-  // Copy input values into the slice
-  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
+  set_array_buffer(compute_encoder, in, 0);
+  set_array_buffer(compute_encoder, in_unfolded, 1);
 
-  // Make strided view
-  std::vector<int> strided_shape = {
-      conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
+  compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
 
-  std::vector<size_t> strided_strides = {
-      in_padded.strides()[0],
-      in_padded.strides()[1] * conv_params.str[0],
-      in_padded.strides()[1],
-      in_padded.strides()[2]};
-  auto flags = in_padded.flags();
+  // Launch unfolding kernel
+  int tgp_x = std::min(conv_params.C, 64);
+  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
+  int tgp_y = 256 / tgp_x;
 
-  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
-  in_strided_view.copy_shared_buffer(
-      in_padded, strided_strides, flags, in_strided_view.size(), 0);
+  MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
+  MTL::Size grid_dims = MTL::Size(
+      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  // Materialize strided view
-  std::vector<int> strided_reshape = {
-      conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
-  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy_gpu(in_strided_view, in_strided, CopyType::General, s);
+  compute_encoder->dispatchThreads(grid_dims, group_dims);
 
   // Perform gemm
-  std::vector<array> copies = {zero, in_padded, in_strided};
+  std::vector<array> copies;
  return steel_matmul(
      s,
      d,
-      /*a = */ in_strided,
+      /*a = */ in_unfolded,
      /*b = */ wt,
      /*c = */ out,
-      /*M = */ strided_reshape[0],
+      /*M = */ unfolded_shape[0],
      /*N = */ conv_params.O,
-      /*K = */ strided_reshape[1],
+      /*K = */ unfolded_shape[1],
      /*batch_size_out = */ 1,
-      /*a_cols = */ strided_reshape[1],
-      /*b_cols = */ strided_reshape[1],
+      /*a_cols = */ unfolded_shape[1],
+      /*b_cols = */ unfolded_shape[1],
      /*a_transposed = */ false,
      /*b_transposed = */ true,
      /*copies = */ copies);
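The unfold above turns the convolution into a single GEMM; a hedged example of the shape bookkeeping (values illustrative, not from the diff):

// 1D conv: N=2, iS=10, C=8, O=16, wS=3, stride 1, padding 1 -> oS=10.
// unfolded_shape = {out.size() / O, wt.size() / O}
//               = {2 * 10 * 16 / 16, 16 * 3 * 8 / 16} = {20, 24}
// steel_matmul then computes out[20 x 16] = in_unfolded[20 x 24] * wt^T,
// i.e. M = 20, N = O = 16, K = 24, with b_transposed = true.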
@@ -95,7 +86,9 @@ void conv_1D_gpu(
     array out,
     const std::vector<int>& padding,
     const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation) {
+    const std::vector<int>& wt_dilation,
+    const std::vector<int>& in_dilation,
+    bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
       /* const int N = */ in.shape(0),
@@ -106,24 +99,19 @@ void conv_1D_gpu(
       /* const int oS[NDIM] = */ {out.shape(1)},
       /* const int str[NDIM] = */ {wt_strides[0]},
       /* const int pad[NDIM] = */ {padding[0]},
-      /* const int dil[NDIM] = */ {wt_dilation[0]},
+      /* const int kdil[NDIM] = */ {wt_dilation[0]},
+      /* const int idil[NDIM] = */ {in_dilation[0]},
       /* const size_t in_strides[NDIM + 2] = */
       {in.strides()[0], in.strides()[1], in.strides()[2]},
       /* const size_t wt_strides[NDIM + 2] = */
       {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
       /* const size_t out_strides[NDIM + 2] = */
       {out.strides()[0], out.strides()[1], out.strides()[2]},
-  };
+      /* const int groups = */ 1,
+      /* const bool flip = */ flip};
 
   // Direct to explicit gemm conv
-  if (wt_dilation[0] == 1) {
-    explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  // Direct to fallback conv
-  else {
-    throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
-  }
+  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
 }
 
 void slow_conv_2D_gpu(
@@ -169,114 +157,262 @@ void implicit_gemm_conv_2D_gpu(
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
-  int bm = 32, bn = 32, bk = 16;
+  // Deduce implicit gemm size
+  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  int implicit_N = conv_params.O;
+  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
+
+  // Determine block and warp tiles
   int wm = 2, wn = 2;
 
+  int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
+  int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
+  int bk = 16;
+
+  if (implicit_N <= 16) {
+    bn = 8;
+    wm = 4;
+    wn = 1;
+  }
+
+  int tn = (implicit_N + bn - 1) / bn;
+  int tm = (implicit_M + bm - 1) / bm;
+  int swizzle_log = 0;
+
+  // Fix small channel specialization
+  int n_channel_specialization = 0;
+  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
+  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
+
+  if (conv_params.C <= 2) {
+    gemm_k_iters = (implicit_K + bk - 1) / bk;
+    n_channel_specialization = conv_params.C;
+  } else if (conv_params.C <= 4) {
+    gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
+    n_channel_specialization = conv_params.C;
+  }
+
+  bool small_filter = (!n_channel_specialization) &&
+      (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);
+
+  // Fix host side helper params
+  int sign = (conv_params.flip ? -1 : 1);
+  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
+  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
+
+  int inp_jump_w = sign * ijw;
+  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
+  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
+      sign * (conv_params.wS[1] - 1) * ijw;
+
+  // Build implicit gemm params
+  ImplicitGemmConv2DParams gemm_params{
+      /* const int M = */ implicit_M,
+      /* const int N = */ implicit_N,
+      /* const int K = */ implicit_K,
+
+      /* const int gemm_k_iterations = */ gemm_k_iters,
+
+      /* const int inp_jump_w = */ inp_jump_w,
+      /* const int inp_jump_h = */ inp_jump_h,
+      /* const int inp_jump_c = */ inp_jump_c,
+
+      /* const int tiles_n = */ tn,
+      /* const int tiles_m = */ tm,
+      /* const int swizzle_log = */ swizzle_log};
 
   // Determine kernel
   std::ostringstream kname;
   kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
-        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
+        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
+        << (n_channel_specialization ? std::to_string(n_channel_specialization)
+                                     : "l")
+        << "_filter_" << (small_filter ? 's' : 'l');
 
   // Encode and dispatch kernel
   auto compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
-  int implicit_N = conv_params.O;
-
-  size_t grid_dim_x = (implicit_N + bn - 1) / bn;
-  size_t grid_dim_y = (implicit_M + bm - 1) / bm;
+  // Deduce grid launch dimensions
+  int tile = 1 << swizzle_log;
+  size_t grid_dim_y = (tm + tile - 1) / tile;
+  size_t grid_dim_x = tn * tile;
 
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
 
   // Encode arrays
   set_array_buffer(compute_encoder, in, 0);
   set_array_buffer(compute_encoder, wt, 1);
   set_array_buffer(compute_encoder, out, 2);
 
   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
+  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
 
   // Launch kernel
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
 }
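A worked example of the specialization picks above (values illustrative):

// 3x3 filter, C = 3 input channels, bk = 16:
//   channel_k_iters       = (3 + 15) / 16 = 1
//   default gemm_k_iters  = 3 * 3 * 1 = 9
//   C <= 4, so the small-channel path packs C up to 4 channels per tap:
//   gemm_k_iters          = (3 * 3 * 4 + 15) / 16 = 3
//   n_channel_specialization = 3, small_filter = false
// The kernel name then ends in "_channel_3_filter_l".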
 
-void explicit_gemm_conv_2D_gpu(
+void implicit_gemm_conv_2D_general_gpu(
     const Stream& s,
     metal::Device& d,
     const array& in,
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
-  // Pad input
-  std::vector<int> padded_shape = {
-      conv_params.N,
-      conv_params.iS[0] + 2 * conv_params.pad[0],
-      conv_params.iS[1] + 2 * conv_params.pad[1],
-      conv_params.C};
-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+  // Deduce implicit gemm size
+  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  int implicit_N = conv_params.O;
+  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
 
-  // Fill with zeros
-  auto zero = array(0, in.dtype());
-  copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
+  // Determine block and warp tiles
+  int wm = 2, wn = 2;
 
-  // Pick input slice from padded
-  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
-      conv_params.pad[1] * in_padded.strides()[2];
-  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
-  in_padded_slice.copy_shared_buffer(
-      in_padded,
-      in_padded.strides(),
-      in_padded.flags(),
-      in_padded_slice.size(),
-      data_offset);
+  // Make jump params
+  int f_wgt_jump_h =
+      std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
+  int f_wgt_jump_w =
+      std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];
 
-  // Copy input values into the slice
-  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
+  int f_out_jump_h =
+      std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
+  int f_out_jump_w =
+      std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];
 
-  // Make strided view
-  std::vector<int> strided_shape = {
-      conv_params.N,
-      conv_params.oS[0],
-      conv_params.oS[1],
-      conv_params.wS[0],
-      conv_params.wS[1],
-      conv_params.C};
+  int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
+  int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
+  int adj_out_hw = adj_out_h * adj_out_w;
+  int adj_implicit_m = conv_params.N * adj_out_hw;
 
-  std::vector<size_t> strided_strides = {
-      in_padded.strides()[0],
-      in_padded.strides()[1] * conv_params.str[0],
-      in_padded.strides()[2] * conv_params.str[1],
-      in_padded.strides()[1],
-      in_padded.strides()[2],
-      in_padded.strides()[3]};
-  auto flags = in_padded.flags();
+  Conv2DGeneralJumpParams jump_params{
+      /* const int f_wgt_jump_h = */ f_wgt_jump_h,
+      /* const int f_wgt_jump_w = */ f_wgt_jump_w,
 
-  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
-  in_strided_view.copy_shared_buffer(
-      in_padded, strided_strides, flags, in_strided_view.size(), 0);
+      /* const int f_out_jump_h = */ f_out_jump_h,
+      /* const int f_out_jump_w = */ f_out_jump_w,
 
-  // Materialize strided view
-  std::vector<int> strided_reshape = {
-      conv_params.N * conv_params.oS[0] * conv_params.oS[1],
-      conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
-  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy_gpu(in_strided_view, in_strided, CopyType::General, s);
+      /* const int adj_out_h = */ adj_out_h,
+      /* const int adj_out_w = */ adj_out_w,
+      /* const int adj_out_hw = */ adj_out_hw,
+      /* const int adj_implicit_m = */ adj_implicit_m};
 
-  // Perform gemm
-  std::vector<array> copies = {zero, in_padded, in_strided};
-  return steel_matmul(
-      s,
-      d,
-      /*a = */ in_strided,
-      /*b = */ wt,
-      /*c = */ out,
-      /*M = */ strided_reshape[0],
-      /*N = */ conv_params.O,
-      /*K = */ strided_reshape[1],
-      /*batch_size_out = */ 1,
-      /*a_cols = */ strided_reshape[1],
-      /*b_cols = */ strided_reshape[1],
-      /*a_transposed = */ false,
-      /*b_transposed = */ true,
-      /*copies = */ copies);
+  // Make base info
+  std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
+  std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);
+
+  int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
+  int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];
+
+  int init_h =
+      (conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
+  int init_w =
+      (conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);
+
+  for (int i = 0; i < f_out_jump_h; ++i) {
+    int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;
+
+    int wh_base = 0;
+    while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
+      wh_base++;
+      ih_loop += jump_h;
+    }
+
+    int wh_size =
+        ((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
+    base_h[i] = {wh_base, wh_size};
+  }
+
+  for (int j = 0; j < f_out_jump_w; ++j) {
+    int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;
+
+    int ww_base = 0;
+    while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
+      ww_base++;
+      iw_loop += jump_w;
+    }
+
+    int ww_size =
+        ((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
+    base_w[j] = {ww_base, ww_size};
+  }
+
+  // Collect block sizes
+  int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
+  int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
+  int bk = 16;
+
+  int tn = (implicit_N + bn - 1) / bn;
+  int tm = (adj_implicit_m + bm - 1) / bm;
+  int swizzle_log = 0;
+
+  // Get channel iteration info
+  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
+  int gemm_k_iters = channel_k_iters;
+
+  // Fix host side helper params
+  int sign = (conv_params.flip ? -1 : 1);
+  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
+  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
+
+  int inp_jump_w = sign * ijw;
+  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
+  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
+      sign * (conv_params.wS[1] - 1) * ijw;
+
+  // Build implicit gemm params
+  ImplicitGemmConv2DParams gemm_params{
+      /* const int M = */ implicit_M,
+      /* const int N = */ implicit_N,
+      /* const int K = */ implicit_K,
+
+      /* const int gemm_k_iterations = */ gemm_k_iters,
+
+      /* const int inp_jump_w = */ inp_jump_w,
+      /* const int inp_jump_h = */ inp_jump_h,
+      /* const int inp_jump_c = */ inp_jump_c,
+
+      /* const int tiles_n = */ tn,
+      /* const int tiles_m = */ tm,
+      /* const int swizzle_log = */ swizzle_log};
+
+  // Determine kernel
+  std::ostringstream kname;
+  kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
+        << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
+
+  // Encode and dispatch kernel
+  auto compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Deduce grid launch dimensions
+  int tile = 1 << swizzle_log;
+  size_t grid_dim_y = (tm + tile - 1) / tile;
+  size_t grid_dim_x = tn * tile;
+  size_t grid_dim_z = f_out_jump_h * f_out_jump_w;
+
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
+
+  // Encode arrays
+  set_array_buffer(compute_encoder, in, 0);
+  set_array_buffer(compute_encoder, wt, 1);
+  set_array_buffer(compute_encoder, out, 2);
+
+  // Encode params
+  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
+  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
+  compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
+
+  compute_encoder->setBytes(
+      base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
+  compute_encoder->setBytes(
+      base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
+
+  // Launch kernel
+  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
 }
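The lcm-based jump parameters above measure how many output positions (and weight taps) it takes for the index pattern to repeat under input dilation; a worked example (values illustrative):

// idil = 2, kdil = 1, str = 1:
//   f_wgt_jump_h = lcm(2, 1) / 1 = 2  // every 2nd weight tap hits a real input
//   f_out_jump_h = lcm(2, 1) / 1 = 2  // pattern repeats every 2 output rows
//   adj_out_h    = ceil(oS[0] / 2)    // rows per pattern class
// The launch grid's z dimension enumerates the
// f_out_jump_h * f_out_jump_w pattern classes.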
 
 void winograd_conv_2D_gpu(
@@ -301,6 +437,7 @@ void winograd_conv_2D_gpu(
   // Fill with zeros
   array zero_arr = array(0, in.dtype());
   copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
+  copies_w.push_back(zero_arr);
 
   // Pick input slice from padded
   size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
@@ -329,7 +466,8 @@ void winograd_conv_2D_gpu(
       /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
       /* const int str[NDIM] = */ {1, 1},
       /* const int pad[NDIM] = */ {0, 0},
-      /* const int dil[NDIM] = */ {1, 1},
+      /* const int kdil[NDIM] = */ {1, 1},
+      /* const int idil[NDIM] = */ {1, 1},
       /* const size_t in_strides[NDIM + 2] = */
       {in_padded.strides()[0],
        in_padded.strides()[1],
@@ -339,6 +477,8 @@ void winograd_conv_2D_gpu(
       {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
       /* const size_t out_strides[NDIM + 2] = */
       {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
+      /* const int groups = */ 1,
+      /* const bool flip = */ false,
   };
 
   int O_c = conv_params.O;
@@ -462,6 +602,8 @@ void conv_2D_gpu(
     const std::vector<int>& padding,
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
+    const std::vector<int>& in_dilation,
+    bool flip,
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<2> conv_params{
@@ -473,37 +615,47 @@ void conv_2D_gpu(
       /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
       /* const int pad[NDIM] = */ {padding[0], padding[1]},
-      /* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
+      /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
+      /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
       /* const size_t in_strides[NDIM + 2] = */
       {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
       /* const size_t wt_strides[NDIM + 2] = */
       {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
       /* const size_t out_strides[NDIM + 2] = */
       {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
+      /* const int groups = */ 1,
+      /* const bool flip = */ flip,
   };
 
+  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
+  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
+  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
+
+  bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
+  bool channels_large = (conv_params.C + conv_params.O) >= 512;
+  bool channels_med = (conv_params.C + conv_params.O) >= 256;
+
   // Direct to winograd conv
-  if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
-      conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
-      conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
-      conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
-      conv_params.dil[1] == 1) {
-    winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
+      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
+      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
+      (channels_large || (channels_med && inp_large))) {
+    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
   }
 
   // Direct to implicit gemm conv
-  else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
-    implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
+      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
+    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
   }
 
+  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
+    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
+  }
+
   // Direct to explicit gemm conv
-  else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
-    explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  // Direct to fallback conv
   else {
-    slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
   }
 }
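Condensed, the dispatch ladder above can be summarized as a standalone helper; the function name below is hypothetical, but the predicates are taken directly from the diff:

// Hypothetical summary of conv_2D_gpu's routing (not code from this PR).
const char* route_conv_2d(const MLXConvParams<2>& p, bool flip) {
  bool stride_one = p.str[0] == 1 && p.str[1] == 1;
  bool kdil_one = p.kdil[0] == 1 && p.kdil[1] == 1;
  bool idil_one = p.idil[0] == 1 && p.idil[1] == 1;
  bool inp_large = p.in_strides[0] >= 1ul << 18;
  bool ch_large = (p.C + p.O) >= 512;
  bool ch_med = (p.C + p.O) >= 256;
  if (!flip && stride_one && kdil_one && idil_one && p.wS[0] == 3 &&
      p.wS[1] == 3 && p.C % 32 == 0 && p.O % 32 == 0 &&
      (ch_large || (ch_med && inp_large))) {
    return "winograd_conv_2D_gpu";
  }
  if (idil_one && (p.C <= 4 || p.C % 16 == 0) &&
      (p.O <= 16 || p.O % 16 == 0)) {
    return "implicit_gemm_conv_2D_gpu";
  }
  if (p.C % 16 == 0 && p.O % 16 == 0) {
    return "implicit_gemm_conv_2D_general_gpu";
  }
  return "explicit_gemm_conv_ND_gpu"; // unfold + GEMM fallback
}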
@@ -534,11 +686,31 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   // 2D conv
   if (out.ndim() == 4) {
     conv_2D_gpu(
-        s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
+        s,
+        d,
+        in,
+        wt,
+        out,
+        padding_,
+        kernel_strides_,
+        kernel_dilation_,
+        input_dilation_,
+        flip_,
+        copies);
   }
   // 1D conv
   else if (out.ndim() == 3) {
-    conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
+    conv_1D_gpu(
+        s,
+        d,
+        in,
+        wt,
+        out,
+        padding_,
+        kernel_strides_,
+        kernel_dilation_,
+        input_dilation_,
+        flip_);
   }
   // Throw error
   else {
@@ -51,11 +51,7 @@ endfunction(build_kernel_base)
 
 function(build_kernel KERNEL)
   set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
-  set(HEADERS_PADDED ${HEADERS})
-  if(${KERNEL} STREQUAL "conv")
-    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
-  endif()
-  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
+  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
endfunction(build_kernel)
 
foreach(KERNEL ${KERNELS})
@@ -1,481 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DInputBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BM;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  int offsets_n[n_rows];
  int offsets_oh[n_rows];
  int offsets_ow[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params.oS[0] * params.oS[1];

    for (int i = 0; i < n_rows; ++i) {
      int offset_nhw = tid.y * BM + bi + i * bstride;
      offsets_n[i] = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      offsets_oh[i] = hw / params.oS[1];
      offsets_ow[i] = hw % params.oS[1];
    }

    (void)lid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
      int n = offsets_n[i];
      int oh = offsets_oh[i];
      int ow = offsets_ow[i];

      int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
      int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];

      // Read from input if in bounds
      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
        const device T* curr_src = src + n * params.in_strides[0] +
            ih * params.in_strides[1] + iw * params.in_strides[2];

#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }
      }

      // Zero pad otherwise
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};
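The input loader above assigns each of its n_rows block rows to one output pixel by decomposing a flat row id; a hedged host-side illustration (standalone C++, values illustrative):

#include <cstdio>

// Mirror of the (n, oh, ow) decomposition in Conv2DInputBlockLoader's
// constructor, for a conv with output spatial size oS = {4, 5}.
int main() {
  int oS0 = 4, oS1 = 5;
  int out_n_pixels = oS0 * oS1;       // 20 output pixels per batch element
  int offset_nhw = 27;                // flat row id handled by one loader row
  int n = offset_nhw / out_n_pixels;  // batch index     -> 1
  int hw = offset_nhw % out_n_pixels; // pixel within it -> 7
  int oh = hw / oS1;                  // output row      -> 1
  int ow = hw % oS1;                  // output column   -> 2
  std::printf("n=%d oh=%d ow=%d\n", n, oh, ow);
  return 0;
}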
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BN;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_.wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    (void)lid;
    (void)tid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src =
        src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = curr_src[i * src_ld + j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between elements
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets within threadgroup
  const int tm;
  const int tn;

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC Conv2DBlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  using loader_a_t =
      Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
  using loader_b_t =
      Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
  using mma_t = Conv2DBlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant MLXConvParams<2>& params [[buffer(3)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;
    const int K = params.wt_strides[0];
    const int N = params.O;

    B += c_col * K;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
    loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_gid, simd_lid);

    for (int k = 0; k < K; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Store results to device memory
    mma_op.store_result(C, N);
  }
};
@@ -1,16 +1,102 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
-#include <metal_stdlib>
+#include <metal_simdgroup>
+#include <metal_simdgroup_matrix>
+#include <metal_stdlib>
 
-#include "mlx/backend/metal/kernels/conv_params.h"
+#include "mlx/backend/metal/kernels/steel/conv/params.h"
+#include "mlx/backend/metal/kernels/bf16.h"
 
-#include "mlx/backend/metal/kernels/conv.h"
+#define MLX_MTL_CONST static constant constexpr const
 
 using namespace metal;
 
 ///////////////////////////////////////////////////////////////////////////////
-/// Slow and naive kernels
+/// Naive unfold with dilation
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename T, int N>
+[[kernel]] void naive_unfold_Nd(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    const constant MLXConvParams<N>* params [[buffer(2)]],
+    uint3 gid [[thread_position_in_grid]]) {
+
+  int filter_size = params->C;
+  for(short i = 0; i < N; i++) filter_size *= params->wS[i];
+
+  int out_pixels = 1;
+  for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
+
+  // Set out
+  out += gid.z * filter_size + gid.y * (params->C);
+
+  // Coordinates in input
+  int is[N] = {0};
+
+  // gid.z: N oS (Batch and row in unfolded output)
+  // gid.y: wS (Filter location to unfold input)
+  // gid.x: C (channel)
+
+  int n = (gid.z) / out_pixels;
+  int oS = (gid.z) % out_pixels;
+  int wS = gid.y;
+
+  bool valid = n < params->N;
+
+  // Unroll dimensions
+  for (int i = N - 1; i >= 0; --i) {
+    int os_ = (oS % params->oS[i]);
+    int ws_ = (wS % params->wS[i]);
+
+    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
+
+    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
+    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
+
+    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
+
+    is[i] = is_ / params->idil[i];
+
+    oS /= params->oS[i];
+    wS /= params->wS[i];
+  }
+
+  if(valid) {
+    size_t in_offset = n * params->in_strides[0];
+
+    for(int i = 0; i < N; ++i) {
+      in_offset += is[i] * params->in_strides[i + 1];
+    }
+
+    out[gid.x] = in[in_offset + gid.x];
+  } else {
+    out[gid.x] = T(0);
+  }
+
+}
+
+#define instantiate_naive_unfold_nd(name, itype, n) \
+  template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
+  [[kernel]] void naive_unfold_Nd( \
+      const device itype* in [[buffer(0)]], \
+      device itype* out [[buffer(1)]], \
+      const constant MLXConvParams<n>* params [[buffer(2)]], \
+      uint3 gid [[thread_position_in_grid]]);
+
+#define instantiate_naive_unfold_nd_dims(name, itype) \
+  instantiate_naive_unfold_nd(name, itype, 1) \
+  instantiate_naive_unfold_nd(name, itype, 2) \
+  instantiate_naive_unfold_nd(name, itype, 3)
+
+instantiate_naive_unfold_nd_dims(float32, float);
+instantiate_naive_unfold_nd_dims(float16, half);
+instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
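A worked example of the dilated index check in the unfold kernel above (values illustrative):

// str = 1, pad = 1, kdil = 1, idil = 2 (input dilation), iS = 5, no flip.
// For output position os_ = 3 and filter tap ws_ = 2:
//   is_    = 3 * 1 - 1 + 2 * 1 = 4
//   is_max = 1 + 2 * (5 - 1)   = 9
//   valid  = (4 >= 0) && (4 < 9) && (4 % 2 == 0)  -> true
//   is[i]  = 4 / 2 = 2   // element 2 of the un-dilated input
// Taps that land between dilated samples (is_ % idil != 0) read as zero.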
+///////////////////////////////////////////////////////////////////////////////
+/// Slow and naive conv2d kernels
+///////////////////////////////////////////////////////////////////////////////
 
 template <typename T,
@@ -58,8 +144,8 @@ template <typename T,
 
       // Local in
       for(int m = 0; m < TM; m++) {
-        int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
-        int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
+        int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
+        int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
 
         bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
         in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
@@ -116,59 +202,6 @@ instantiate_naive_conv_2d_blocks(float32, float);
 instantiate_naive_conv_2d_blocks(float16, half);
 instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
 
-///////////////////////////////////////////////////////////////////////////////
-/// Implicit gemm kernels
-///////////////////////////////////////////////////////////////////////////////
-
-template <typename T,
-          int BM,
-          int BN,
-          int BK,
-          int WM,
-          int WN>
-[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
-    const device T* in [[buffer(0)]],
-    const device T* wt [[buffer(1)]],
-    device T* out [[buffer(2)]],
-    const constant MLXConvParams<2>& params [[buffer(3)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
-
-  using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
-
-  threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
-
-  gemm_kernel::run(
-      in, wt, out,
-      params, tgp_memory,
-      tid, lid, simd_gid, simd_lid
-  );
-
-}
-
-#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
-  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
-  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
-      const device itype* in [[buffer(0)]], \
-      const device itype* wt [[buffer(1)]], \
-      device itype* out [[buffer(2)]], \
-      const constant MLXConvParams<2>& params [[buffer(3)]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
-
-#define instantiate_implicit_2d_blocks(name, itype) \
-  instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
-  instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
-  instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
-
-instantiate_implicit_2d_blocks(float32, float);
-instantiate_implicit_2d_blocks(float16, half);
-instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
 
 ///////////////////////////////////////////////////////////////////////////////
 /// Winograd kernels
 ///////////////////////////////////////////////////////////////////////////////
@@ -1,19 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

template <int NDIM>
struct MLXConvParams {
  const int N; // Batch size
  const int C; // In channels
  const int O; // Out channels
  const int iS[NDIM]; // Input spatial dim
  const int wS[NDIM]; // Weight spatial dim
  const int oS[NDIM]; // Output spatial dim
  const int str[NDIM]; // Kernel strides
  const int pad[NDIM]; // Input padding
  const int dil[NDIM]; // Kernel dilation
  const size_t in_strides[NDIM + 2]; // In strides
  const size_t wt_strides[NDIM + 2]; // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
};
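For reference, a sketch of the replacement parameter struct in mlx/backend/metal/kernels/steel/conv/params.h, inferred from the initializer comments in conv.cpp above; treat field order and comments as assumptions rather than a verbatim copy of the new header:

// Inferred layout of the new MLXConvParams (not shown verbatim in this diff).
template <int NDIM>
struct MLXConvParams {
  const int N;                        // Batch size
  const int C;                        // In channels
  const int O;                        // Out channels
  const int iS[NDIM];                 // Input spatial dim
  const int wS[NDIM];                 // Weight spatial dim
  const int oS[NDIM];                 // Output spatial dim
  const int str[NDIM];                // Kernel strides
  const int pad[NDIM];                // Input padding
  const int kdil[NDIM];               // Kernel dilation
  const int idil[NDIM];               // Input dilation
  const size_t in_strides[NDIM + 2];  // In strides
  const size_t wt_strides[NDIM + 2];  // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
  const int groups;                   // Input channel groups
  const bool flip;                    // Flip filter (true convolution)
};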
mlx/backend/metal/kernels/steel/conv/conv.h (new file, 11 lines)
@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"

using namespace metal;
using namespace mlx::steel;
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal (new file, 189 lines)
@@ -0,0 +1,189 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/mma.h"

#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          int N_CHANNELS = 0,
          bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  using namespace mlx::steel;

  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DInputBlockLoaderSmallChannels<
          T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,

      // Else go to general loader
      typename metal::conditional_t<
          // Check if filter size is small enough
          SMALL_FILTER,

          // Go to small filter specialization
          Conv2DInputBlockLoaderSmallFilter<
              T, BM, BN, BK, tgp_size, tgp_padding_a>,

          // Else go to large filter generalization
          Conv2DInputBlockLoaderLargeFilter<
              T, BM, BN, BK, tgp_size, tgp_padding_a>
      >
  >;

  // Weight loader
  using loader_b_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DWeightBlockLoaderSmallChannels<
          T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,

      // Else go to general loader
      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
  >;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;
  const int N = gemm_params->N;

  B += c_col * K;
  C += c_row * N + c_col;

  const int2 offsets_a(0, c_row);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
  loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations = gemm_params->gemm_k_iterations;
  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  short tgp_bm = min(BM, gemm_params->M - c_row);
  short tgp_bn = min(BN, gemm_params->N - c_col);
  mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));

}

#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
      const device itype* A [[buffer(0)]], \
      const device itype* B [[buffer(1)]], \
      device itype* C [[buffer(2)]], \
      const constant MLXConvParams<2>* params [[buffer(3)]], \
      const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)

#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
@@ -0,0 +1,209 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/mma.h"

#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;
using namespace mlx::steel;

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          typename AccumType = float,
          typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = Conv2DInputBlockLoaderGeneral<
      T, BM, BN, BK, tgp_size, tgp_padding_a>;

  // Weight loader
  using loader_b_t = Conv2DWeightBlockLoaderGeneral<
      T, BM, BN, BK, tgp_size, tgp_padding_b>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;
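  // Swizzle the launch grid: 2^swizzle_log consecutive tid.x values share
  // one tile column (tid_x) and spread over consecutive tile rows (tid_y),
  // which improves cache reuse between concurrently resident threadgroups.
  // E.g. with swizzle_log = 1 and tid.y = 0, tid.x = 0..3 maps to
  // (tid_x, tid_y) = (0, 0), (0, 1), (1, 0), (1, 1).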
  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int tid_z = tid.z;

  const int base_oh = tid_z / jump_params->f_out_jump_w;
  const int base_ow = tid_z % jump_params->f_out_jump_w;
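  // The z dimension of the launch grid indexes output "phases": slice tid_z
  // covers the output pixels whose row is base_oh mod f_out_jump_h and whose
  // column is base_ow mod f_out_jump_w. The host-prepared base_h / base_w
  // tables then give the first weight tap and the tap count that contribute
  // to that phase.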
  const int base_wh = base_h[base_oh].weight_base;
  const int base_ww = base_w[base_ow].weight_base;

  const int base_wh_size = base_h[base_oh].weight_size;
  const int base_ww_size = base_w[base_ow].weight_size;

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;

  B += c_col * K;

  const int4 offsets_a(0, c_row, base_oh, base_ow);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
  loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
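  // One K iteration consumes one BK-wide channel block of one weight tap, so
  // the loop runs (taps in this phase) x (channel blocks):
  // base_wh_size * base_ww_size taps, times gemm_params->gemm_k_iterations.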

  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  {
    // Adjust for simdgroup and thread location
    int offset_m = c_row + mma_op.sm + mma_op.tm;
    int offset_n = c_col + mma_op.sn + mma_op.tn;
    C += offset_n;

    if (offset_n >= gemm_params->N)
      return;

    short diff = gemm_params->N - offset_n;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < mma_t::TM; i++) {
      int cm = offset_m + i * mma_t::TM_stride;

      int n = cm / jump_params->adj_out_hw;
      int hw = cm % jump_params->adj_out_hw;
      int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
      int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
        int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];

        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < mma_t::TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
          int offset = offset_cm + (j * mma_t::TN_stride);

          // Apply epilogue and output C
          if (j * mma_t::TN_stride < diff) {
            C[offset] = Epilogue::apply(accum[0]);
          }

          if (j * mma_t::TN_stride + 1 < diff) {
            C[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }
}

#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
  template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
      const device itype* A [[buffer(0)]], \
      const device itype* B [[buffer(1)]], \
      device itype* C [[buffer(2)]], \
      const constant MLXConvParams<2>* params [[buffer(3)]], \
      const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
      const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
      const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
      const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)

#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
6
mlx/backend/metal/kernels/steel/conv/loader.h
Normal file
@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"
449
mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
Normal file
@@ -0,0 +1,449 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderLargeFilter {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderLargeFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Adjust for flip
      if (params->flip) {
        ih += (params->wS[0] - 1) * params->kdil[0];
        iw += (params->wS[1] - 1) * params->kdil[1];
      }

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int ih = read_ih[i] + weight_h * params->kdil[0];
      int iw = read_iw[i] + weight_w * params->kdil[1];

      // Read from input if in bounds
      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
          (iw >= 0 && iw < params->iS[1])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallFilter {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  using mask_t = short;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  mask_t mask_h[n_rows];
  mask_t mask_w[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderSmallFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    int read_n[n_rows];
    int read_ih[n_rows];
    int read_iw[n_rows];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Adjust for flip
      if (params->flip) {
        ih += (params->wS[0] - 1) * params->kdil[0];
        iw += (params->wS[1] - 1) * params->kdil[1];
      }

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2] + bj;
    }

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      mask_h[i] = 0;
      mask_w[i] = 0;
    }

    for (short kh = 0; kh < params->wS[0]; kh++) {
      short flip_h = params->flip ? params->wS[0] - kh - 1 : kh;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int n = read_n[i];
        int ih = read_ih[i] + flip_h * params->kdil[0];

        bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];

        mask_h[i] |= (in_bounds << kh);
      }
    }

    for (short kw = 0; kw < params->wS[1]; kw++) {
      short flip_w = params->flip ? params->wS[1] - kw - 1 : kw;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int iw = read_iw[i] + flip_w * params->kdil[1];

        bool in_bounds = iw >= 0 && iw < params->iS[1];

        mask_w[i] |= (in_bounds << kw);
      }
    }
  }
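
  // mask_h[i] / mask_w[i] record, one bit per kernel tap, whether row i's
  // input row / column is in bounds for that tap; load_unsafe then tests a
  // single bit per dimension instead of redoing the index arithmetic. With
  // mask_t a 16-bit short this covers at most 16 taps per spatial dimension,
  // which is why this loader is reserved for small filters.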

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    mask_t h_mask = mask_t(1) << weight_h;
    mask_t w_mask = mask_t(1) << weight_w;

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Read from input if in bounds
      if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  int weight_hw;

  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_hw(0),
        read_n(offsets.y + bi),
        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((read_n + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = src[i * src_ld + j];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_hw < (params->wS[1] * params->wS[0])) {
      src += params->wt_strides[2];
      return;
    }

    weight_hw = 0;

    src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
  }
};

} // namespace steel
} // namespace mlx
319
mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
Normal file
@@ -0,0 +1,319 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <short n_channels_>
struct ChannelHelper {
  STEEL_CONST short n_channels = n_channels_;
  STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
  STEEL_CONST short excess = vec_size - n_channels_;
};

template <>
struct ChannelHelper<1> {
  STEEL_CONST short n_channels = 1;
  STEEL_CONST short vec_size = 1;
  STEEL_CONST short excess = 0;
};

template <>
struct ChannelHelper<2> {
  STEEL_CONST short n_channels = 2;
  STEEL_CONST short vec_size = 2;
  STEEL_CONST short excess = 0;
};

template <>
struct ChannelHelper<3> {
  STEEL_CONST short n_channels = 3;
  STEEL_CONST short vec_size = 4;
  STEEL_CONST short excess = 1;
};

template <>
struct ChannelHelper<4> {
  STEEL_CONST short n_channels = 4;
  STEEL_CONST short vec_size = 4;
  STEEL_CONST short excess = 0;
};
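
// ChannelHelper picks the vectorized read width for small channel counts:
// e.g. ChannelHelper<3> reads vec_size = 4 lanes with excess = 1 trailing
// lane to be zero-filled, while ChannelHelper<2> reads exactly 2 lanes with
// no excess.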

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallChannels {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  short weight_hw;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_hw(thread_idx % TCOLS) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    if (weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    int wh = (weight_hw / params->wS[1]);
    int ww = (weight_hw % params->wS[1]);

    int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
    int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;

    int weight_h = flip_h * params->kdil[0];
    int weight_w = flip_w * params->kdil[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int ih = read_ih[i] + weight_h;
      int iw = read_iw[i] + weight_w;

      // Read from input if in bounds
      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
          (iw >= 0 && iw < params->iS[1])) {
        const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
            weight_w * params->in_strides[2];

        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < n_channels; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }

        STEEL_PRAGMA_UNROLL
        for (short j = n_channels; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_hw += TCOLS;
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderSmallChannels {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  int weight_hw;

  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld),
        params(params_),
        weight_hw(thread_idx % TCOLS),
        read_n(offsets.y + bi),
        do_read(read_n + BN <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    if (bi >= BROWS || bj >= BCOLS)
      return;

    if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }

      return;
    }

    const device T* curr_src = src + weight_hw * params->wt_strides[2];

    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < n_channels; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }

        STEEL_PRAGMA_UNROLL
        for (short j = n_channels; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
    } else {
      for (short i = 0; i < BROWS; i += TROWS) {
        if (((read_n + i) < params->O)) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < n_channels; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }

          STEEL_PRAGMA_UNROLL
          for (short j = n_channels; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_hw += TCOLS;
  }
};

} // namespace steel
} // namespace mlx
288
mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
Normal file
@@ -0,0 +1,288 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  const short base_wh;
  const short base_ww;

  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int4 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / jump_params->adj_out_hw;
      int hw = offset_nhw % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];

      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;

      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
      int iw_dil = read_iw[i] + w_flip * params->kdil[1];

      int ih = ih_dil / params->idil[0];
      int iw = iw_dil / params->idil[1];
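      // ih_dil / iw_dil index the virtually input-dilated image; dividing by
      // idil maps back to the stored input. The phase decomposition in
      // jump_params is arranged so that visited taps land on multiples of
      // idil, so the division should be exact for in-bounds reads.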

      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];

      // Read from input if in bounds
      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = (src[i])[offset + j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    weight_h = base_wh;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += BK;
    }
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  const short base_wh;
  const short base_ww;

  short weight_h;
  short weight_w;

  const int start_row;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_),
        start_row(offsets.y + bi) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    if ((start_row + BN <= params->O)) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((start_row + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    weight_h = base_wh;

    src += BK;
  }
};

} // namespace steel
} // namespace mlx
62
mlx/backend/metal/kernels/steel/conv/params.h
Normal file
@@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.

#pragma once

template <int NDIM>
struct MLXConvParams {
  const int N; // Batch size
  const int C; // In channels
  const int O; // Out channels
  const int iS[NDIM]; // Input spatial dim
  const int wS[NDIM]; // Weight spatial dim
  const int oS[NDIM]; // Output spatial dim
  const int str[NDIM]; // Kernel strides
  const int pad[NDIM]; // Input padding
  const int kdil[NDIM]; // Kernel dilation
  const int idil[NDIM]; // Input dilation
  const size_t in_strides[NDIM + 2]; // In strides
  const size_t wt_strides[NDIM + 2]; // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
  const int groups; // Input channel groups
  const bool flip;
};
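
// Illustrative example (not part of this diff): an NHWC conv with batch 1,
// 32 in channels, 64 out channels, a 32x32 input, a 3x3 kernel, stride 1,
// padding 1, and no dilation or flip corresponds to N=1, C=32, O=64,
// iS={32, 32}, wS={3, 3}, oS={32, 32}, str={1, 1}, pad={1, 1},
// kdil={1, 1}, idil={1, 1}, groups=1, flip=false.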

namespace mlx {
namespace steel {

struct ImplicitGemmConv2DParams {
  const int M;
  const int N;
  const int K;

  const int gemm_k_iterations;

  const int inp_jump_w;
  const int inp_jump_h;
  const int inp_jump_c;

  const int tiles_n;
  const int tiles_m;
  const int swizzle_log;
};

struct Conv2DGeneralJumpParams {
  const int f_wgt_jump_h;
  const int f_wgt_jump_w;

  const int f_out_jump_h;
  const int f_out_jump_w;

  const int adj_out_h;
  const int adj_out_w;
  const int adj_out_hw;
  const int adj_implicit_m;
};

struct Conv2DGeneralBaseInfo {
  int weight_base;
  int weight_size;
};

} // namespace steel
} // namespace mlx
@@ -4,6 +4,7 @@

#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

@@ -2,9 +2,15 @@

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
@@ -167,6 +173,9 @@ struct BlockMMA {
    C += (sm + tm) * ldc + (tn + sn);
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
@@ -236,6 +245,9 @@ struct BlockMMA {
    D += (sm + tm) * ldd + tn + sn;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {

@@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/params.h"
@@ -3,7 +3,6 @@
#pragma once

#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/host.h"

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
@@ -8,7 +8,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/host.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"