// Copyright © 2023-2024 Apple Inc.

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Naive unfold with dilation
///////////////////////////////////////////////////////////////////////////////

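// Unfolds (im2col) the input of shape (N, *spatial_dims, C) into rows of
// length prod(kernel_dims) * C: one row per (batch, output location), with
// the channel index fastest within each kernel position. Taps that fall
// outside the input or off the input-dilation grid are written as zeros.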
template <typename T, int N>
[[kernel]] void naive_unfold_Nd(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant MLXConvParams<N>* params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]) {
  int filter_size = params->C;
  for (short i = 0; i < N; i++)
    filter_size *= params->wS[i];

  int out_pixels = 1;
  for (short i = 0; i < N; i++)
    out_pixels *= params->oS[i];

  // Set out
  out += gid.z * filter_size + gid.y * (params->C);

  // Coordinates in input
  int is[N] = {0};

  // gid.z: N oS (Batch and row in unfolded output)
  // gid.y: wS (Filter location to unfold input)
  // gid.x: C (channel)

  int n = (gid.z) / out_pixels;
  int oS = (gid.z) % out_pixels;
  int wS = gid.y;

  bool valid = n < params->N;

  // Unroll dimensions
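  // Walk the spatial dims from innermost to outermost: peel off the per-dim
  // output coordinate (os_) and kernel coordinate (ws_), map them to the
  // input coordinate is_ = os_ * stride - pad + ws_ * kernel_dilation, and
  // accept the tap only if it is in bounds and lands on the input-dilation
  // grid.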
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);

    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);

    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);

    is[i] = is_ / params->idil[i];

    oS /= params->oS[i];
    wS /= params->wS[i];
  }

  if (valid) {
    size_t in_offset = n * params->in_strides[0];

    for (int i = 0; i < N; ++i) {
      in_offset += is[i] * params->in_strides[i + 1];
    }

    out[gid.x] = in[in_offset + gid.x];
  } else {
    out[gid.x] = T(0);
  }
}

// This kernel unfolds the input array of size (N, *spatial_dims, C)
// into an array of size (N x *spatial_dims, C x *kernel_dims).
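// Unlike naive_unfold_Nd above, each unfolded row here is channel-major: for
// a given channel (gid.x) the prod(kernel_dims) kernel taps are stored
// contiguously, i.e. the transpose of the per-row layout produced by
// naive_unfold_Nd.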
template <typename T, int N>
[[kernel]] void naive_unfold_transpose_Nd(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant MLXConvParams<N>* params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]) {
  int filter_size = params->C;
  for (short i = 0; i < N; i++)
    filter_size *= params->wS[i];

  int out_pixels = 1;
  for (short i = 0; i < N; i++)
    out_pixels *= params->oS[i];

  // Set out
  out += gid.z * filter_size + gid.x * (filter_size / params->C);

  // Coordinates in input
  int is[N] = {0};

  // gid.z: N oS (Batch and row in unfolded output)
  // gid.y: wS (Filter location to unfold input)
  // gid.x: C (channel)

  int n = (gid.z) / out_pixels;
  int oS = (gid.z) % out_pixels;
  int wS = gid.y;

  bool valid = n < params->N;

  // Unroll dimensions
  int kernel_stride = 1;
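  // kernel_stride tracks the flattened kernel-position stride of dimension i
  // (innermost dimension fastest); ws_ * kernel_stride advances `out` to this
  // thread's tap slot within its channel block.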
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);
    out += ws_ * kernel_stride;

    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);

    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);

    is[i] = is_ / params->idil[i];

    oS /= params->oS[i];
    wS /= params->wS[i];

    kernel_stride *= params->wS[i];
  }

  if (valid) {
    size_t in_offset = n * params->in_strides[0];

    for (int i = 0; i < N; ++i) {
      in_offset += is[i] * params->in_strides[i + 1];
    }

    out[0] = in[in_offset + gid.x];
  } else {
    out[0] = T(0);
  }
}

#define instantiate_naive_unfold_nd(name, itype, n) \
  template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
  naive_unfold_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]); \
  template \
      [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
  naive_unfold_transpose_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]);

#define instantiate_naive_unfold_nd_dims(name, itype) \
  instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
      name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)

instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////

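// Reference direct convolution: the grid's z dimension indexes the batch,
// each threadgroup is BM x BN threads, and every thread accumulates a TM x TN
// register tile (TM flattened output positions x TN output channels) by
// looping over the full filter window and all input channels.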
template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const int BC = 16>
[[kernel]] void naive_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)simd_gid;
  (void)simd_lid;

  out += tid.z * params.out_strides[0];
  in += tid.z * params.in_strides[0];

  int out_o = tid.y * BN * TN + lid.y * TN;
  int out_hw = tid.x * BM * TM + lid.x * TM;

  int out_h[TM];
  int out_w[TM]; // one (h, w) pair per thread row; both indexed by m < TM

  for (int m = 0; m < TM; ++m) {
    int mm = (out_hw + m);
    out_h[m] = mm / params.oS[1];
    out_w[m] = mm % params.oS[1];
  }

  T in_local[TM];
  T wt_local[TN];
  T out_local[TM * TN] = {T(0)};

  for (int h = 0; h < params.wS[0]; ++h) {
    for (int w = 0; w < params.wS[1]; ++w) {
      for (int c = 0; c < params.C; ++c) {
        // Local in
        for (int m = 0; m < TM; m++) {
          int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
          int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];

          bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
          in_local[m] = valid
              ? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
              : T(0);
        }

        // Load weight
        for (int n = 0; n < TN; ++n) {
          int o = out_o + n;
          wt_local[n] = o < params.O
              ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
                   w * params.wt_strides[2] + c]
              : T(0);
        }

        // Accumulate
        for (int m = 0; m < TM; ++m) {
          for (int n = 0; n < TN; ++n) {
            out_local[m * TN + n] += in_local[m] * wt_local[n];
          }
        }
      }
    }
  }

  for (int m = 0; m < TM; ++m) {
    for (int n = 0; n < TN; ++n) {
      if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
          (out_o + n) < params.O)
        out[out_h[m] * params.out_strides[1] +
            out_w[m] * params.out_strides[2] + out_o + n] =
            out_local[m * TN + n];
    }
  }
}

// Instantiations

#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
                       "_tn" #tn)]] [[kernel]] void \
  naive_conv_2d<itype, bm, bn, tm, tn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_naive_conv_2d_blocks(name, itype) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)

instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

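// Transforms for the F(6x6, 3x3) Winograd algorithm: 6x6 output tiles are
// produced from 8x8 input tiles (6 + 3 - 1), and all three transform matrices
// are padded to 8x8 so they fit a single simdgroup matrix. in_transform,
// out_transform, and wt_transform are the data, output, and filter transforms
// used by the kernels below.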
template <int M, int R, int S>
struct WinogradTransforms {};

template <>
struct WinogradTransforms<6, 3, 8> {
  MLX_MTL_CONST int OUT_TILE_SIZE = 6;
  MLX_MTL_CONST int FILTER_SIZE = 3;
  MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
  MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
  MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
      {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
      {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
      {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
      {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
      {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
      {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
      {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
  };

  MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
      {1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
      {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
      {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
      {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
      {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
      {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
      {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
  };

  MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {1.00, 0.00, 0.00},
      {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
      {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
      {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
      {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
      {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
      {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
      {0.00, 0.00, 1.00},
  };
};

constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];

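// Transforms the (O, R, R, C) filters into the Winograd domain: each
// simdgroup handles one output channel and computes G w Gt per input channel,
// writing the result transposed as (8, 8, C, O).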
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
winograd_conv_2d_weight_transform(
    const device T* wt_in [[buffer(0)]],
    device T* wt_out [[buffer(1)]],
    const constant int& C [[buffer(2)]],
    const constant int& O [[buffer(3)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  using WGT = WinogradTransforms<M, R, 8>;

  // Get lane position in simdgroup
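  // Each lane is assumed to own two adjacent elements of the 8x8 simdgroup
  // matrix, (sm, sn) and (sm, sn + 1); the qid/sm/sn arithmetic below
  // reproduces that lane-to-element mapping so the transform matrices can be
  // loaded straight into thread_elements().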
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize G matrix
  simdgroup_matrix<float, 8, 8> G;
  G.thread_elements()[0] = WGT::wt_transform[sm][sn];
  G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];

  // Initialize Gt matrix
  simdgroup_matrix<float, 8, 8> Gt;
  Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
  Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];

  // Move to the correct output filter
  size_t ko = BO * tid + simd_group_id;
  wt_in += ko * R * R * C;

  // wt_out is stored transposed (A x A x C x O)
  short ohw_0 = sm * 8 + sn;
  short ohw_1 = sm * 8 + sn + 1;
  device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
  device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;

  // Prepare shared memory
  threadgroup T Ws[BO][R][R][BC];

  // Loop over C
  for (int bc = 0; bc < C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read into shared memory
    for (int kh = 0; kh < R; ++kh) {
      for (int kw = 0; kw < R; ++kw) {
        for (int kc = simd_lane_id; kc < BC; kc += 32) {
          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for (int c = 0; c < BC; ++c) {
      simdgroup_matrix<float, 8, 8> g;
      g.thread_elements()[0] =
          sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
      g.thread_elements()[1] =
          sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

      simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
      wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
      wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
    }

    wt_in += BC;
    wt_out_0 += BC * O;
    wt_out_1 += BC * O;
  }
}

#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
  template [[host_name("winograd_conv_2d_weight_transform_" #name \
                       "_bc" #bc)]] [[kernel]] void \
  winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
      const constant int& O [[buffer(3)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

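// Transforms 8x8 input tiles into the Winograd domain: tid.x/tid.y select the
// tile column/row and tid.z the batch element. Each threadgroup loads its 8x8
// tile BC channels at a time into shared memory using its WM x WN simdgroups,
// applies Bt d B per channel, and stores the result interleaved as
// (8, 8, n_tiles, C).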
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_input_transform(
    const device T* inp_in [[buffer(0)]],
    device T* inp_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
  constexpr int A = WGT::IN_TILE_SIZE;
  constexpr int N_SIMD_GROUPS = WM * WN;

  // Get lane position in simdgroup
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize B matrix
  simdgroup_matrix<float, 8, 8> B;
  B.thread_elements()[0] = WGT::in_transform[sm][sn];
  B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];

  // Initialize Bt matrix
  simdgroup_matrix<float, 8, 8> Bt;
  Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
  Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];

  // Resolve input tile
  constexpr int TH = (A / WM);
  constexpr int TW = (A / WN);
  int kh = TH * (simd_group_id / WN);
  int kw = TW * (simd_group_id % WN);
  int bh = M * tid.y + kh;
  int bw = M * tid.x + kw;

  // Move to the correct input tile
  inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
      bw * params.in_strides[2];

  // Pre compute strides
  int jump_in[TH][TW];

  for (int h = 0; h < TH; h++) {
    for (int w = 0; w < TW; w++) {
      jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
    }
  }

  // inp_out is stored interleaved (A x A x tiles x C)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id =
      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t ohw_0 = sm * 8 + sn;
  size_t ohw_1 = sm * 8 + sn + 1;
  device T* inp_out_0 =
      inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
  device T* inp_out_1 =
      inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;

  // Prepare shared memory
  threadgroup T Is[A][A][BC];

  // Loop over C
  for (int bc = 0; bc < params.C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read into shared memory
    for (int h = 0; h < TH; h++) {
      for (int w = 0; w < TW; w++) {
        const device T* in_ptr = inp_in + jump_in[h][w];
        for (int c = simd_lane_id; c < BC; c += 32) {
          Is[kh + h][kw + w][c] = in_ptr[c];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
      simdgroup_matrix<float, 8, 8> I;
      I.thread_elements()[0] = Is[sm][sn][c];
      I.thread_elements()[1] = Is[sm][sn + 1][c];

      simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
      inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
      inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
    }

    inp_in += BC;
    inp_out_0 += BC;
    inp_out_1 += BC;
  }
}

#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
  template [[host_name("winograd_conv_2d_input_transform_" #name \
                       "_bc" #bc)]] [[kernel]] void \
  winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

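// Inverse transform: reads Winograd-domain outputs stored as
// (8, 8, n_tiles, O), applies the output transform (At y A) per output
// channel to recover a 6x6 spatial tile, and scatters it into the
// (N, H, W, O) output, skipping positions beyond the valid output extent.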
template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_output_transform(
    const device T* out_in [[buffer(0)]],
    device T* out_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
  constexpr int N_SIMD_GROUPS = WM * WN;

  // Get lane position in simdgroup
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize the output transform matrix A (stored in local B)
  simdgroup_matrix<float, 8, 8> B;
  B.thread_elements()[0] = WGT::out_transform[sm][sn];
  B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];

  // Initialize the transposed output transform At (stored in local Bt)
  simdgroup_matrix<float, 8, 8> Bt;
  Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
  Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];

  // Out_in comes in shape (A x A x tiles x O)
  // We do transform and then write out to out_out in shape (N, H, W, O)

  // Resolve output tile
  constexpr int TH = (M / WM);
  constexpr int TW = (M / WN);
  int kh = TH * (simd_group_id / WN);
  int kw = TW * (simd_group_id % WN);
  int bh = M * tid.y + kh;
  int bw = M * tid.x + kw;

  // Move to the correct input tile
  out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
      bw * params.out_strides[2];

  // Pre compute strides
  int jump_in[TH][TW];

  for (int h = 0; h < TH; h++) {
    for (int w = 0; w < TW; w++) {
      bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
      jump_in[h][w] =
          valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
    }
  }

  // out_in is stored interleaved (A x A x tiles x O)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id =
      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t ohw_0 = sm * 8 + sn;
  size_t ohw_1 = sm * 8 + sn + 1;
  const device T* out_in_0 =
      out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
  const device T* out_in_1 =
      out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;

  // Prepare shared memory
  threadgroup T Os[M][M][BO];

  // Loop over O
  for (int bo = 0; bo < params.O; bo += BO) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
      simdgroup_matrix<float, 8, 8> O_mat;
      O_mat.thread_elements()[0] = out_in_0[c];
      O_mat.thread_elements()[1] = out_in_1[c];

      simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
      if ((sm < M) && (sn < M)) {
        Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
      }
      if ((sm < M) && ((sn + 1) < M)) {
        Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read out from shared memory
    for (int h = 0; h < TH; h++) {
      for (int w = 0; w < TW; w++) {
        if (jump_in[h][w] >= 0) {
          device T* out_ptr = out_out + jump_in[h][w];
          for (int c = simd_lane_id; c < BO; c += 32) {
            out_ptr[c] = Os[kh + h][kw + w][c];
          }
        }
      }
    }

    out_out += BO;
    out_in_0 += BO;
    out_in_1 += BO;
  }
}

#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
  template [[host_name("winograd_conv_2d_output_transform_" #name \
                       "_bo" #bo)]] [[kernel]] void \
  winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
      const device itype* out_in [[buffer(0)]], \
      device itype* out_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_winograd_conv_2d(name, itype) \
  instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
  instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
  instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on

// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
instantiate_winograd_conv_2d(float16, half); // clang-format on