Winograd Update for Small batches (#1803)

* Build in padding to Winograd kernels
* Add new fused Winograd kernel
* Enable weight flipping in Winograd kernels
This commit is contained in:
Jagrit Digani
2025-02-14 13:08:13 -08:00
committed by GitHub
parent 7aea5b1895
commit 2dc307f2e6
4 changed files with 505 additions and 86 deletions

View File

@@ -326,7 +326,13 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
template <
typename T,
int BC = 32,
int BO = 4,
bool do_flip = false,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
winograd_conv_2d_weight_transform(
const device T* wt_in [[buffer(0)]],
@@ -373,7 +379,12 @@ winograd_conv_2d_weight_transform(
for (int kh = 0; kh < R; ++kh) {
for (int kw = 0; kw < R; ++kw) {
for (int kc = simd_lane_id; kc < BC; kc += 32) {
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
if (do_flip) {
Ws[simd_group_id][R - 1 - kh][R - 1 - kw][kc] =
wt_in[kh * R * C + kw * C + kc];
} else {
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
}
}
}
}
@@ -398,10 +409,10 @@ winograd_conv_2d_weight_transform(
}
}
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
template [[host_name("winograd_conv_2d_weight_transform_" #name \
"_bc" #bc)]] [[kernel]] void \
winograd_conv_2d_weight_transform<itype, bc>( \
#define instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, f) \
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc \
"_flip" #f)]] [[kernel]] void \
winograd_conv_2d_weight_transform<itype, bc, 4, f>( \
const device itype* wt_in [[buffer(0)]], \
device itype* wt_out [[buffer(1)]], \
const constant int& C [[buffer(2)]], \
@@ -410,6 +421,10 @@ winograd_conv_2d_weight_transform(
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 0) \
instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 1)
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_input_transform(
@@ -445,10 +460,17 @@ winograd_conv_2d_input_transform(
// Resolve input tile
constexpr int TH = (A / WM);
constexpr int TW = (A / WN);
int kh = TH * (simd_group_id / WN);
int kw = TW * (simd_group_id % WN);
int bh = M * tid.y + kh;
int bw = M * tid.x + kw;
const int kh = TH * (simd_group_id / WN);
const int kw = TW * (simd_group_id % WN);
const int bh = M * tid.y + kh - params.pad[1];
const int bw = M * tid.x + kw - params.pad[0];
const bool is_edge_w_lo = bw < 0;
const bool is_edge_h_lo = bh < 0;
const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
const bool is_edge =
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
// Move to the correct input tile
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
@@ -484,8 +506,21 @@ winograd_conv_2d_input_transform(
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
if (is_edge) {
if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
((bw + w) < 0 || (bw + w) >= params.iS[0])) {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = T(0);
}
} else {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
} else {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
}
}
@@ -652,3 +687,373 @@ winograd_conv_2d_output_transform(
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
instantiate_winograd_conv_2d(float16, half); // clang-format on
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
template <
typename T,
bool do_flip = false,
int WM = 4,
int WN = 1,
typename AccumType = float>
[[kernel]] void winograd_fused(
const device T* input [[buffer(0)]],
const device T* weight [[buffer(1)]],
device T* output [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_per_grid [[threadgroups_per_grid]],
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)tgp_per_grid;
// Winograd F(n x n, r x r)
// n x n output window
constexpr short FN = 2;
// r x r filter size
constexpr short FR = 3;
// a x a input window, a = n + r - 1
constexpr short FA = 4;
constexpr short kFragSize = 8; // MMA frag size
constexpr short BT = 8; // Tile block size
constexpr short BO = 8; // Output channel block size
constexpr short BC = 8; // Input channel block size
// clang-format off
static_assert(BT % (1 * kFragSize) == 0 &&
BO % (1 * kFragSize) == 0 &&
BC % kFragSize == 0,
"Matmuls sizes must be compatible with fragments");
// clang-format on
// Prepare for matmul
// Warp tile sizes for matmul
constexpr short TM = (FA * FA * BT) / (WM * kFragSize);
constexpr short TN = (BO) / (WN * kFragSize);
constexpr short TK = (BC) / (kFragSize);
// Warp primitives
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tiles sizes for matmul
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Itile;
MMATile<AccumType, TK, TN, MMAFrag_acc_t> Wtile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Otile[TM];
for (int im = 0; im < 4; im++) {
Otile[im].clear();
}
// Threadgroup memory for Weights and Inputs
constexpr short BS = BT > BO ? BT : BO;
threadgroup T Wt[FA * FA * BC * BO];
threadgroup T It[FA * FA * BS * BS];
// Get thread position in tile
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
static_assert(FA * FA * BT == 32 * WM * WN, "Each thread loads one pixel.");
const int thr_idx = simd_group_id * 32 + simd_lane_id;
const int thr_t = thr_idx / (FA * FA);
const int thr_hw = thr_idx % (FA * FA);
const int thr_h = thr_hw / FA;
const int thr_w = thr_hw % FA;
// Get batch, tile, and output idx for warp
const int b_idx = tid.z;
const int t_idx = BT * tid.y + thr_t;
const int o_idx = BO * tid.x + thr_t;
// Divide tile into h, w tile
uniform<int> oHu = make_uniform(params.oS[0]);
uniform<int> oWu = make_uniform(params.oS[1]);
uniform<int> tHu = (oHu + make_uniform(FN - 1)) / make_uniform(FN);
uniform<int> tWu = (oWu + make_uniform(FN - 1)) / make_uniform(FN);
const int oH_idx = FN * (t_idx / tWu);
const int oW_idx = FN * (t_idx % tWu);
const int iH_idx = oH_idx + thr_h - params.pad[0];
const int iW_idx = oW_idx + thr_w - params.pad[1];
// Move to correct location
// clang-format off
input += b_idx * params.in_strides[0] + // N
iH_idx * params.in_strides[1] + // H
iW_idx * params.in_strides[2]; // W
weight += o_idx * params.wt_strides[0] + // O
thr_h * params.wt_strides[1] + // H
thr_w * params.wt_strides[2]; // W
// clang-format on
// Do edge check prep for input
const bool is_edge_w_lo = iH_idx < 0;
const bool is_edge_h_lo = iW_idx < 0;
const bool is_edge_w_hi = iH_idx >= params.iS[0];
const bool is_edge_h_hi = iW_idx >= params.iS[1];
const bool is_edge =
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
// Iterate over C
for (int c = 0; c < params.C; c += BC) {
#define tmp_load_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
#define tmp_load_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
#define tmp_trns_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
#define tmp_trns_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load weight
if (thr_h < FR && thr_w < FR && thr_t < BO) {
for (int ic = 0; ic < BC; ic++) {
if (do_flip) {
Wt[tmp_load_wt_idx(thr_t, FR - 1 - thr_h, FR - 1 - thr_w, ic)] =
weight[c + ic];
} else {
Wt[tmp_load_wt_idx(thr_t, thr_h, thr_w, ic)] = weight[c + ic];
}
}
}
// Load input
if (is_edge) {
for (int ic = 0; ic < BC; ic++) {
It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = T(0);
}
} else {
for (int ic = 0; ic < BC; ic++) {
It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = input[c + ic];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Transform weight
if (lid.z == 0) {
const short ic = lid.y;
const short io = lid.x;
T tmp_0[4][4];
T tmp_1[4][4];
for (int ii = 0; ii < 3; ++ii) {
for (int jj = 0; jj < 3; ++jj) {
tmp_0[ii][jj] = Wt[tmp_load_wt_idx(io, ii, jj, ic)];
}
}
//////////////////////////////////////////////
tmp_1[0][0] = tmp_0[0][0];
tmp_1[0][1] = tmp_0[0][1];
tmp_1[0][2] = tmp_0[0][2];
tmp_1[1][0] = T(0.5) * (tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0]);
tmp_1[1][1] = T(0.5) * (tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1]);
tmp_1[1][2] = T(0.5) * (tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2]);
tmp_1[2][0] = tmp_1[1][0] - tmp_0[1][0];
tmp_1[2][1] = tmp_1[1][1] - tmp_0[1][1];
tmp_1[2][2] = tmp_1[1][2] - tmp_0[1][2];
tmp_1[3][0] = tmp_0[2][0];
tmp_1[3][1] = tmp_0[2][1];
tmp_1[3][2] = tmp_0[2][2];
//////////////////////////////////////////////
tmp_0[0][0] = tmp_1[0][0];
tmp_0[1][0] = tmp_1[1][0];
tmp_0[2][0] = tmp_1[2][0];
tmp_0[3][0] = tmp_1[3][0];
tmp_0[0][1] = T(0.5) * (tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2]);
tmp_0[1][1] = T(0.5) * (tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2]);
tmp_0[2][1] = T(0.5) * (tmp_1[2][0] + tmp_1[2][1] + tmp_1[2][2]);
tmp_0[3][1] = T(0.5) * (tmp_1[3][0] + tmp_1[3][1] + tmp_1[3][2]);
tmp_0[0][2] = tmp_0[0][1] - tmp_1[0][1];
tmp_0[1][2] = tmp_0[1][1] - tmp_1[1][1];
tmp_0[2][2] = tmp_0[2][1] - tmp_1[2][1];
tmp_0[3][2] = tmp_0[3][1] - tmp_1[3][1];
tmp_0[0][3] = tmp_1[0][2];
tmp_0[1][3] = tmp_1[1][2];
tmp_0[2][3] = tmp_1[2][2];
tmp_0[3][3] = tmp_1[3][2];
for (int ii = 0; ii < 4; ++ii) {
for (int jj = 0; jj < 4; ++jj) {
Wt[tmp_trns_wt_idx(io, ii, jj, ic)] = tmp_0[ii][jj];
}
}
}
// Transform input
else {
const short it = lid.y;
const short ic = lid.x;
T tmp_0[4][4];
T tmp_1[4][4];
for (int ii = 0; ii < 4; ++ii) {
for (int jj = 0; jj < 4; ++jj) {
tmp_0[ii][jj] = It[tmp_load_in_idx(it, ii, jj, ic)];
}
}
//////////////////////////////////////////////
tmp_1[0][0] = tmp_0[0][0] - tmp_0[2][0];
tmp_1[0][1] = tmp_0[0][1] - tmp_0[2][1];
tmp_1[0][2] = tmp_0[0][2] - tmp_0[2][2];
tmp_1[0][3] = tmp_0[0][3] - tmp_0[2][3];
tmp_1[1][0] = tmp_0[1][0] + tmp_0[2][0];
tmp_1[1][1] = tmp_0[1][1] + tmp_0[2][1];
tmp_1[1][2] = tmp_0[1][2] + tmp_0[2][2];
tmp_1[1][3] = tmp_0[1][3] + tmp_0[2][3];
tmp_1[2][0] = tmp_0[2][0] - tmp_0[1][0];
tmp_1[2][1] = tmp_0[2][1] - tmp_0[1][1];
tmp_1[2][2] = tmp_0[2][2] - tmp_0[1][2];
tmp_1[2][3] = tmp_0[2][3] - tmp_0[1][3];
tmp_1[3][0] = tmp_0[1][0] - tmp_0[3][0];
tmp_1[3][1] = tmp_0[1][1] - tmp_0[3][1];
tmp_1[3][2] = tmp_0[1][2] - tmp_0[3][2];
tmp_1[3][3] = tmp_0[1][3] - tmp_0[3][3];
//////////////////////////////////////////////
tmp_0[0][0] = tmp_1[0][0] - tmp_1[0][2];
tmp_0[1][0] = tmp_1[1][0] - tmp_1[1][2];
tmp_0[2][0] = tmp_1[2][0] - tmp_1[2][2];
tmp_0[3][0] = tmp_1[3][0] - tmp_1[3][2];
tmp_0[0][1] = tmp_1[0][1] + tmp_1[0][2];
tmp_0[1][1] = tmp_1[1][1] + tmp_1[1][2];
tmp_0[2][1] = tmp_1[2][1] + tmp_1[2][2];
tmp_0[3][1] = tmp_1[3][1] + tmp_1[3][2];
tmp_0[0][2] = tmp_1[0][2] - tmp_1[0][1];
tmp_0[1][2] = tmp_1[1][2] - tmp_1[1][1];
tmp_0[2][2] = tmp_1[2][2] - tmp_1[2][1];
tmp_0[3][2] = tmp_1[3][2] - tmp_1[3][1];
tmp_0[0][3] = tmp_1[0][1] - tmp_1[0][3];
tmp_0[1][3] = tmp_1[1][1] - tmp_1[1][3];
tmp_0[2][3] = tmp_1[2][1] - tmp_1[2][3];
tmp_0[3][3] = tmp_1[3][1] - tmp_1[3][3];
for (int ii = 0; ii < 4; ++ii) {
for (int jj = 0; jj < 4; ++jj) {
It[tmp_trns_in_idx(it, ii, jj, ic)] = tmp_0[ii][jj];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
for (int im = 0; im < 4; im++) {
simdgroup_barrier(mem_flags::mem_none);
Itile.template load<T, 1, 1, BS, 1>(
&It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
simdgroup_barrier(mem_flags::mem_none);
Wtile.template load<T, 1, 1, BO, 1>(
&Wt[simd_group_id * FA * BC * BO + im * BC * BO + sm * BO + sn]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Otile[im], Itile, Wtile, Otile[im]);
}
}
// Transform and write output
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int im = 0; im < 4; im++) {
Otile[im].template store<T, 1, 1, BS, 1>(
&It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.z == 0) {
const short it = lid.y;
const short io = lid.x;
T tmp_0[4][4];
T tmp_1[2][4];
T tmp_2[2][2];
for (int ii = 0; ii < 4; ++ii) {
for (int jj = 0; jj < 4; ++jj) {
tmp_0[ii][jj] = It[tmp_trns_in_idx(it, ii, jj, io)];
}
}
tmp_1[0][0] = tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0];
tmp_1[0][1] = tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1];
tmp_1[0][2] = tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2];
tmp_1[0][3] = tmp_0[0][3] + tmp_0[1][3] + tmp_0[2][3];
tmp_1[1][0] = tmp_0[1][0] - tmp_0[2][0] - tmp_0[3][0];
tmp_1[1][1] = tmp_0[1][1] - tmp_0[2][1] - tmp_0[3][1];
tmp_1[1][2] = tmp_0[1][2] - tmp_0[2][2] - tmp_0[3][2];
tmp_1[1][3] = tmp_0[1][3] - tmp_0[2][3] - tmp_0[3][3];
tmp_2[0][0] = tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2];
tmp_2[1][0] = tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2];
tmp_2[0][1] = tmp_1[0][1] - tmp_1[0][2] - tmp_1[0][3];
tmp_2[1][1] = tmp_1[1][1] - tmp_1[1][2] - tmp_1[1][3];
const int oH_i = FN * ((BT * tid.y + it) / tWu);
const int oW_i = FN * ((BT * tid.y + it) % tWu);
// clang-format off
output += b_idx * params.out_strides[0] + // N
oH_i * params.out_strides[1] + // H
oW_i * params.out_strides[2] + // W
BO * tid.x; // C
// clang-format on
output[0 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
tmp_2[0][0];
output[0 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
tmp_2[0][1];
output[1 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
tmp_2[1][0];
output[1 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
tmp_2[1][1];
}
}
// clang-format off
#define instantiate_winograd_conv_2d_fused(name, itype, f) \
template [[host_name("winograd_conv_2d_fused_" #name "_flip" #f)]] \
[[kernel]] void winograd_fused<itype, f>( \
const device itype* input [[buffer(0)]], \
const device itype* weight [[buffer(1)]], \
device itype* output [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_per_grid [[threadgroups_per_grid]], \
ushort simd_group_id [[simdgroup_index_in_threadgroup]], \
ushort simd_lane_id [[thread_index_in_simdgroup]]);
#define instantiate_winograd_conv_2d_fused_2(name, itype) \
instantiate_winograd_conv_2d_fused(name, itype, 0) \
instantiate_winograd_conv_2d_fused(name, itype, 1)
instantiate_winograd_conv_2d_fused_2(float32, float);
instantiate_winograd_conv_2d_fused_2(float16, float16_t);
instantiate_winograd_conv_2d_fused_2(bfloat16, bfloat16_t);
// clang-format on