mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Winograd Update for Small batches (#1803)
* Build in padding to Winograd kernels * Add new fused Winograd kernel * Enable weight flipping in Winograd kernels
This commit is contained in:
@@ -326,7 +326,13 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
||||
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
||||
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
||||
|
||||
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
|
||||
template <
|
||||
typename T,
|
||||
int BC = 32,
|
||||
int BO = 4,
|
||||
bool do_flip = false,
|
||||
int M = 6,
|
||||
int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
||||
winograd_conv_2d_weight_transform(
|
||||
const device T* wt_in [[buffer(0)]],
|
||||
@@ -373,7 +379,12 @@ winograd_conv_2d_weight_transform(
|
||||
for (int kh = 0; kh < R; ++kh) {
|
||||
for (int kw = 0; kw < R; ++kw) {
|
||||
for (int kc = simd_lane_id; kc < BC; kc += 32) {
|
||||
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
||||
if (do_flip) {
|
||||
Ws[simd_group_id][R - 1 - kh][R - 1 - kw][kc] =
|
||||
wt_in[kh * R * C + kw * C + kc];
|
||||
} else {
|
||||
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -398,10 +409,10 @@ winograd_conv_2d_weight_transform(
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
||||
template [[host_name("winograd_conv_2d_weight_transform_" #name \
|
||||
"_bc" #bc)]] [[kernel]] void \
|
||||
winograd_conv_2d_weight_transform<itype, bc>( \
|
||||
#define instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, f) \
|
||||
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc \
|
||||
"_flip" #f)]] [[kernel]] void \
|
||||
winograd_conv_2d_weight_transform<itype, bc, 4, f>( \
|
||||
const device itype* wt_in [[buffer(0)]], \
|
||||
device itype* wt_out [[buffer(1)]], \
|
||||
const constant int& C [[buffer(2)]], \
|
||||
@@ -410,6 +421,10 @@ winograd_conv_2d_weight_transform(
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
||||
instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 0) \
|
||||
instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 1)
|
||||
|
||||
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
winograd_conv_2d_input_transform(
|
||||
@@ -445,10 +460,17 @@ winograd_conv_2d_input_transform(
|
||||
// Resolve input tile
|
||||
constexpr int TH = (A / WM);
|
||||
constexpr int TW = (A / WN);
|
||||
int kh = TH * (simd_group_id / WN);
|
||||
int kw = TW * (simd_group_id % WN);
|
||||
int bh = M * tid.y + kh;
|
||||
int bw = M * tid.x + kw;
|
||||
const int kh = TH * (simd_group_id / WN);
|
||||
const int kw = TW * (simd_group_id % WN);
|
||||
const int bh = M * tid.y + kh - params.pad[1];
|
||||
const int bw = M * tid.x + kw - params.pad[0];
|
||||
|
||||
const bool is_edge_w_lo = bw < 0;
|
||||
const bool is_edge_h_lo = bh < 0;
|
||||
const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
|
||||
const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
|
||||
const bool is_edge =
|
||||
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
|
||||
|
||||
// Move to the correct input tile
|
||||
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
|
||||
@@ -484,8 +506,21 @@ winograd_conv_2d_input_transform(
|
||||
for (int h = 0; h < TH; h++) {
|
||||
for (int w = 0; w < TW; w++) {
|
||||
const device T* in_ptr = inp_in + jump_in[h][w];
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
if (is_edge) {
|
||||
if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
|
||||
((bw + w) < 0 || (bw + w) >= params.iS[0])) {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = T(0);
|
||||
}
|
||||
} else {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -652,3 +687,373 @@ winograd_conv_2d_output_transform(
|
||||
instantiate_winograd_conv_2d(float32, float);
|
||||
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
|
||||
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
|
||||
|
||||
template <
|
||||
typename T,
|
||||
bool do_flip = false,
|
||||
int WM = 4,
|
||||
int WN = 1,
|
||||
typename AccumType = float>
|
||||
[[kernel]] void winograd_fused(
|
||||
const device T* input [[buffer(0)]],
|
||||
const device T* weight [[buffer(1)]],
|
||||
device T* output [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
(void)tgp_per_grid;
|
||||
|
||||
// Winograd F(n x n, r x r)
|
||||
// n x n output window
|
||||
constexpr short FN = 2;
|
||||
// r x r filter size
|
||||
constexpr short FR = 3;
|
||||
// a x a input window, a = n + r - 1
|
||||
constexpr short FA = 4;
|
||||
|
||||
constexpr short kFragSize = 8; // MMA frag size
|
||||
|
||||
constexpr short BT = 8; // Tile block size
|
||||
constexpr short BO = 8; // Output channel block size
|
||||
constexpr short BC = 8; // Input channel block size
|
||||
|
||||
// clang-format off
|
||||
static_assert(BT % (1 * kFragSize) == 0 &&
|
||||
BO % (1 * kFragSize) == 0 &&
|
||||
BC % kFragSize == 0,
|
||||
"Matmuls sizes must be compatible with fragments");
|
||||
// clang-format on
|
||||
|
||||
// Prepare for matmul
|
||||
|
||||
// Warp tile sizes for matmul
|
||||
constexpr short TM = (FA * FA * BT) / (WM * kFragSize);
|
||||
constexpr short TN = (BO) / (WN * kFragSize);
|
||||
constexpr short TK = (BC) / (kFragSize);
|
||||
|
||||
// Warp primitives
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
// Warp tiles sizes for matmul
|
||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Itile;
|
||||
MMATile<AccumType, TK, TN, MMAFrag_acc_t> Wtile;
|
||||
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Otile[TM];
|
||||
|
||||
for (int im = 0; im < 4; im++) {
|
||||
Otile[im].clear();
|
||||
}
|
||||
|
||||
// Threadgroup memory for Weights and Inputs
|
||||
constexpr short BS = BT > BO ? BT : BO;
|
||||
threadgroup T Wt[FA * FA * BC * BO];
|
||||
threadgroup T It[FA * FA * BS * BS];
|
||||
|
||||
// Get thread position in tile
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
const short sm = simd_coord.y;
|
||||
const short sn = simd_coord.x;
|
||||
|
||||
static_assert(FA * FA * BT == 32 * WM * WN, "Each thread loads one pixel.");
|
||||
const int thr_idx = simd_group_id * 32 + simd_lane_id;
|
||||
const int thr_t = thr_idx / (FA * FA);
|
||||
const int thr_hw = thr_idx % (FA * FA);
|
||||
const int thr_h = thr_hw / FA;
|
||||
const int thr_w = thr_hw % FA;
|
||||
|
||||
// Get batch, tile, and output idx for warp
|
||||
const int b_idx = tid.z;
|
||||
const int t_idx = BT * tid.y + thr_t;
|
||||
const int o_idx = BO * tid.x + thr_t;
|
||||
|
||||
// Divide tile into h, w tile
|
||||
uniform<int> oHu = make_uniform(params.oS[0]);
|
||||
uniform<int> oWu = make_uniform(params.oS[1]);
|
||||
uniform<int> tHu = (oHu + make_uniform(FN - 1)) / make_uniform(FN);
|
||||
uniform<int> tWu = (oWu + make_uniform(FN - 1)) / make_uniform(FN);
|
||||
|
||||
const int oH_idx = FN * (t_idx / tWu);
|
||||
const int oW_idx = FN * (t_idx % tWu);
|
||||
const int iH_idx = oH_idx + thr_h - params.pad[0];
|
||||
const int iW_idx = oW_idx + thr_w - params.pad[1];
|
||||
|
||||
// Move to correct location
|
||||
|
||||
// clang-format off
|
||||
input += b_idx * params.in_strides[0] + // N
|
||||
iH_idx * params.in_strides[1] + // H
|
||||
iW_idx * params.in_strides[2]; // W
|
||||
|
||||
weight += o_idx * params.wt_strides[0] + // O
|
||||
thr_h * params.wt_strides[1] + // H
|
||||
thr_w * params.wt_strides[2]; // W
|
||||
// clang-format on
|
||||
|
||||
// Do edge check prep for input
|
||||
const bool is_edge_w_lo = iH_idx < 0;
|
||||
const bool is_edge_h_lo = iW_idx < 0;
|
||||
const bool is_edge_w_hi = iH_idx >= params.iS[0];
|
||||
const bool is_edge_h_hi = iW_idx >= params.iS[1];
|
||||
const bool is_edge =
|
||||
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
|
||||
|
||||
// Iterate over C
|
||||
for (int c = 0; c < params.C; c += BC) {
|
||||
#define tmp_load_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
|
||||
#define tmp_load_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
|
||||
|
||||
#define tmp_trns_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
|
||||
#define tmp_trns_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load weight
|
||||
if (thr_h < FR && thr_w < FR && thr_t < BO) {
|
||||
for (int ic = 0; ic < BC; ic++) {
|
||||
if (do_flip) {
|
||||
Wt[tmp_load_wt_idx(thr_t, FR - 1 - thr_h, FR - 1 - thr_w, ic)] =
|
||||
weight[c + ic];
|
||||
} else {
|
||||
Wt[tmp_load_wt_idx(thr_t, thr_h, thr_w, ic)] = weight[c + ic];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load input
|
||||
if (is_edge) {
|
||||
for (int ic = 0; ic < BC; ic++) {
|
||||
It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = T(0);
|
||||
}
|
||||
} else {
|
||||
for (int ic = 0; ic < BC; ic++) {
|
||||
It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = input[c + ic];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Transform weight
|
||||
if (lid.z == 0) {
|
||||
const short ic = lid.y;
|
||||
const short io = lid.x;
|
||||
|
||||
T tmp_0[4][4];
|
||||
T tmp_1[4][4];
|
||||
|
||||
for (int ii = 0; ii < 3; ++ii) {
|
||||
for (int jj = 0; jj < 3; ++jj) {
|
||||
tmp_0[ii][jj] = Wt[tmp_load_wt_idx(io, ii, jj, ic)];
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////
|
||||
|
||||
tmp_1[0][0] = tmp_0[0][0];
|
||||
tmp_1[0][1] = tmp_0[0][1];
|
||||
tmp_1[0][2] = tmp_0[0][2];
|
||||
|
||||
tmp_1[1][0] = T(0.5) * (tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0]);
|
||||
tmp_1[1][1] = T(0.5) * (tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1]);
|
||||
tmp_1[1][2] = T(0.5) * (tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2]);
|
||||
|
||||
tmp_1[2][0] = tmp_1[1][0] - tmp_0[1][0];
|
||||
tmp_1[2][1] = tmp_1[1][1] - tmp_0[1][1];
|
||||
tmp_1[2][2] = tmp_1[1][2] - tmp_0[1][2];
|
||||
|
||||
tmp_1[3][0] = tmp_0[2][0];
|
||||
tmp_1[3][1] = tmp_0[2][1];
|
||||
tmp_1[3][2] = tmp_0[2][2];
|
||||
|
||||
//////////////////////////////////////////////
|
||||
tmp_0[0][0] = tmp_1[0][0];
|
||||
tmp_0[1][0] = tmp_1[1][0];
|
||||
tmp_0[2][0] = tmp_1[2][0];
|
||||
tmp_0[3][0] = tmp_1[3][0];
|
||||
|
||||
tmp_0[0][1] = T(0.5) * (tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2]);
|
||||
tmp_0[1][1] = T(0.5) * (tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2]);
|
||||
tmp_0[2][1] = T(0.5) * (tmp_1[2][0] + tmp_1[2][1] + tmp_1[2][2]);
|
||||
tmp_0[3][1] = T(0.5) * (tmp_1[3][0] + tmp_1[3][1] + tmp_1[3][2]);
|
||||
|
||||
tmp_0[0][2] = tmp_0[0][1] - tmp_1[0][1];
|
||||
tmp_0[1][2] = tmp_0[1][1] - tmp_1[1][1];
|
||||
tmp_0[2][2] = tmp_0[2][1] - tmp_1[2][1];
|
||||
tmp_0[3][2] = tmp_0[3][1] - tmp_1[3][1];
|
||||
|
||||
tmp_0[0][3] = tmp_1[0][2];
|
||||
tmp_0[1][3] = tmp_1[1][2];
|
||||
tmp_0[2][3] = tmp_1[2][2];
|
||||
tmp_0[3][3] = tmp_1[3][2];
|
||||
|
||||
for (int ii = 0; ii < 4; ++ii) {
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
Wt[tmp_trns_wt_idx(io, ii, jj, ic)] = tmp_0[ii][jj];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transform input
|
||||
else {
|
||||
const short it = lid.y;
|
||||
const short ic = lid.x;
|
||||
|
||||
T tmp_0[4][4];
|
||||
T tmp_1[4][4];
|
||||
|
||||
for (int ii = 0; ii < 4; ++ii) {
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
tmp_0[ii][jj] = It[tmp_load_in_idx(it, ii, jj, ic)];
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////
|
||||
|
||||
tmp_1[0][0] = tmp_0[0][0] - tmp_0[2][0];
|
||||
tmp_1[0][1] = tmp_0[0][1] - tmp_0[2][1];
|
||||
tmp_1[0][2] = tmp_0[0][2] - tmp_0[2][2];
|
||||
tmp_1[0][3] = tmp_0[0][3] - tmp_0[2][3];
|
||||
|
||||
tmp_1[1][0] = tmp_0[1][0] + tmp_0[2][0];
|
||||
tmp_1[1][1] = tmp_0[1][1] + tmp_0[2][1];
|
||||
tmp_1[1][2] = tmp_0[1][2] + tmp_0[2][2];
|
||||
tmp_1[1][3] = tmp_0[1][3] + tmp_0[2][3];
|
||||
|
||||
tmp_1[2][0] = tmp_0[2][0] - tmp_0[1][0];
|
||||
tmp_1[2][1] = tmp_0[2][1] - tmp_0[1][1];
|
||||
tmp_1[2][2] = tmp_0[2][2] - tmp_0[1][2];
|
||||
tmp_1[2][3] = tmp_0[2][3] - tmp_0[1][3];
|
||||
|
||||
tmp_1[3][0] = tmp_0[1][0] - tmp_0[3][0];
|
||||
tmp_1[3][1] = tmp_0[1][1] - tmp_0[3][1];
|
||||
tmp_1[3][2] = tmp_0[1][2] - tmp_0[3][2];
|
||||
tmp_1[3][3] = tmp_0[1][3] - tmp_0[3][3];
|
||||
|
||||
//////////////////////////////////////////////
|
||||
tmp_0[0][0] = tmp_1[0][0] - tmp_1[0][2];
|
||||
tmp_0[1][0] = tmp_1[1][0] - tmp_1[1][2];
|
||||
tmp_0[2][0] = tmp_1[2][0] - tmp_1[2][2];
|
||||
tmp_0[3][0] = tmp_1[3][0] - tmp_1[3][2];
|
||||
|
||||
tmp_0[0][1] = tmp_1[0][1] + tmp_1[0][2];
|
||||
tmp_0[1][1] = tmp_1[1][1] + tmp_1[1][2];
|
||||
tmp_0[2][1] = tmp_1[2][1] + tmp_1[2][2];
|
||||
tmp_0[3][1] = tmp_1[3][1] + tmp_1[3][2];
|
||||
|
||||
tmp_0[0][2] = tmp_1[0][2] - tmp_1[0][1];
|
||||
tmp_0[1][2] = tmp_1[1][2] - tmp_1[1][1];
|
||||
tmp_0[2][2] = tmp_1[2][2] - tmp_1[2][1];
|
||||
tmp_0[3][2] = tmp_1[3][2] - tmp_1[3][1];
|
||||
|
||||
tmp_0[0][3] = tmp_1[0][1] - tmp_1[0][3];
|
||||
tmp_0[1][3] = tmp_1[1][1] - tmp_1[1][3];
|
||||
tmp_0[2][3] = tmp_1[2][1] - tmp_1[2][3];
|
||||
tmp_0[3][3] = tmp_1[3][1] - tmp_1[3][3];
|
||||
|
||||
for (int ii = 0; ii < 4; ++ii) {
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
It[tmp_trns_in_idx(it, ii, jj, ic)] = tmp_0[ii][jj];
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do matmul
|
||||
for (int im = 0; im < 4; im++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
Itile.template load<T, 1, 1, BS, 1>(
|
||||
&It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
Wtile.template load<T, 1, 1, BO, 1>(
|
||||
&Wt[simd_group_id * FA * BC * BO + im * BC * BO + sm * BO + sn]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
tile_matmad(Otile[im], Itile, Wtile, Otile[im]);
|
||||
}
|
||||
}
|
||||
|
||||
// Transform and write output
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int im = 0; im < 4; im++) {
|
||||
Otile[im].template store<T, 1, 1, BS, 1>(
|
||||
&It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (lid.z == 0) {
|
||||
const short it = lid.y;
|
||||
const short io = lid.x;
|
||||
|
||||
T tmp_0[4][4];
|
||||
T tmp_1[2][4];
|
||||
T tmp_2[2][2];
|
||||
|
||||
for (int ii = 0; ii < 4; ++ii) {
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
tmp_0[ii][jj] = It[tmp_trns_in_idx(it, ii, jj, io)];
|
||||
}
|
||||
}
|
||||
|
||||
tmp_1[0][0] = tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0];
|
||||
tmp_1[0][1] = tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1];
|
||||
tmp_1[0][2] = tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2];
|
||||
tmp_1[0][3] = tmp_0[0][3] + tmp_0[1][3] + tmp_0[2][3];
|
||||
|
||||
tmp_1[1][0] = tmp_0[1][0] - tmp_0[2][0] - tmp_0[3][0];
|
||||
tmp_1[1][1] = tmp_0[1][1] - tmp_0[2][1] - tmp_0[3][1];
|
||||
tmp_1[1][2] = tmp_0[1][2] - tmp_0[2][2] - tmp_0[3][2];
|
||||
tmp_1[1][3] = tmp_0[1][3] - tmp_0[2][3] - tmp_0[3][3];
|
||||
|
||||
tmp_2[0][0] = tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2];
|
||||
tmp_2[1][0] = tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2];
|
||||
|
||||
tmp_2[0][1] = tmp_1[0][1] - tmp_1[0][2] - tmp_1[0][3];
|
||||
tmp_2[1][1] = tmp_1[1][1] - tmp_1[1][2] - tmp_1[1][3];
|
||||
|
||||
const int oH_i = FN * ((BT * tid.y + it) / tWu);
|
||||
const int oW_i = FN * ((BT * tid.y + it) % tWu);
|
||||
|
||||
// clang-format off
|
||||
output += b_idx * params.out_strides[0] + // N
|
||||
oH_i * params.out_strides[1] + // H
|
||||
oW_i * params.out_strides[2] + // W
|
||||
BO * tid.x; // C
|
||||
|
||||
// clang-format on
|
||||
|
||||
output[0 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
|
||||
tmp_2[0][0];
|
||||
output[0 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
|
||||
tmp_2[0][1];
|
||||
output[1 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
|
||||
tmp_2[1][0];
|
||||
output[1 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
|
||||
tmp_2[1][1];
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_winograd_conv_2d_fused(name, itype, f) \
|
||||
template [[host_name("winograd_conv_2d_fused_" #name "_flip" #f)]] \
|
||||
[[kernel]] void winograd_fused<itype, f>( \
|
||||
const device itype* input [[buffer(0)]], \
|
||||
const device itype* weight [[buffer(1)]], \
|
||||
device itype* output [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_per_grid [[threadgroups_per_grid]], \
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_winograd_conv_2d_fused_2(name, itype) \
|
||||
instantiate_winograd_conv_2d_fused(name, itype, 0) \
|
||||
instantiate_winograd_conv_2d_fused(name, itype, 1)
|
||||
|
||||
instantiate_winograd_conv_2d_fused_2(float32, float);
|
||||
instantiate_winograd_conv_2d_fused_2(float16, float16_t);
|
||||
instantiate_winograd_conv_2d_fused_2(bfloat16, bfloat16_t);
|
||||
// clang-format on
|
||||
Reference in New Issue
Block a user