From 2dc307f2e6ee9364f591878cdd5fccd1eef8fd58 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 14 Feb 2025 13:08:13 -0800 Subject: [PATCH] Winograd Update for Small batches (#1803) * Build in padding to Winograd kernels * Add new fused Winograd kernel * Enable weight flipping in Winograd kernels --- mlx/backend/metal/conv.cpp | 158 +++++----- mlx/backend/metal/kernels/conv.metal | 429 ++++++++++++++++++++++++++- mlx/ops.cpp | 2 +- python/tests/test_conv.py | 2 +- 4 files changed, 505 insertions(+), 86 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 3e42f7d2f..6356ad9ba 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -533,6 +533,45 @@ void implicit_gemm_conv_2D_general_gpu( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void winograd_conv_2D_fused_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params, + std::vector& copies_w) { + int O_c = conv_params.O; + int C_c = conv_params.C; + + int N_tiles_n = conv_params.N; + int N_tiles_h = (conv_params.oS[0] + 1) / 2; + int N_tiles_w = (conv_params.oS[1] + 1) / 2; + int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w; + + int bc = 32; + int wm = 4; + int wn = 1; + std::ostringstream kname; + kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip" + << conv_params.flip; + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_input_array(wt, 1); + compute_encoder.set_output_array(out, 2); + + compute_encoder.set_bytes(conv_params, 3); + + MTL::Size group_dims = MTL::Size(8, 8, 2); + MTL::Size grid_dims = + MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + void winograd_conv_2D_gpu( const Stream& s, metal::Device& d, @@ -541,67 +580,6 @@ void winograd_conv_2D_gpu( array out, const MLXConvParams<2>& conv_params, std::vector& copies_w) { - Shape padded_shape = { - conv_params.N, - conv_params.iS[0] + 2 * conv_params.pad[0], - conv_params.iS[1] + 2 * conv_params.pad[1], - conv_params.C}; - - padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2; - padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2; - - array in_padded(std::move(padded_shape), in.dtype(), nullptr, {}); - - // Fill with zeros - array zero_arr = array(0, in.dtype()); - fill_gpu(zero_arr, in_padded, s); - copies_w.push_back(zero_arr); - - // Pick input slice from padded - size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] + - conv_params.pad[1] * in_padded.strides()[2]; - array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); - in_padded_slice.copy_shared_buffer( - in_padded, - in_padded.strides(), - in_padded.flags(), - in_padded_slice.size(), - data_offset); - - // Copy input values into the slice - copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s); - - copies_w.push_back(in_padded_slice); - copies_w.push_back(in_padded); - - MLXConvParams<2> conv_params_updated{ - /* const int N = */ static_cast(in_padded.shape(0)), - /* const int C = */ static_cast(in_padded.shape(3)), - /* const int O = */ static_cast(wt.shape(0)), - /* const int iS[NDIM] = */ - {static_cast(in_padded.shape(1)), - static_cast(in_padded.shape(2))}, - /* const int wS[NDIM] = */ - {static_cast(wt.shape(1)), static_cast(wt.shape(2))}, - /* const int oS[NDIM] = */ - {static_cast(out.shape(1)), static_cast(out.shape(2))}, - /* const int str[NDIM] = */ {1, 1}, - /* const int pad[NDIM] = */ {0, 0}, - /* const int kdil[NDIM] = */ {1, 1}, - /* const int idil[NDIM] = */ {1, 1}, - /* const size_t in_strides[NDIM + 2] = */ - {in_padded.strides()[0], - in_padded.strides()[1], - in_padded.strides()[2], - in_padded.strides()[3]}, - /* const size_t wt_strides[NDIM + 2] = */ - {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, - /* const size_t out_strides[NDIM + 2] = */ - {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, - /* const int groups = */ 1, - /* const bool flip = */ false, - }; - int O_c = conv_params.O; int C_c = conv_params.C; @@ -620,7 +598,7 @@ void winograd_conv_2D_gpu( int bo = 4; std::ostringstream kname; kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc" - << bc; + << bc << "_flip" << conv_params.flip; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); @@ -653,10 +631,10 @@ void winograd_conv_2D_gpu( auto kernel = d.get_kernel(kname.str()); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(in_padded, 0); + compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(inp_wg, 1); - compute_encoder.set_bytes(conv_params_updated, 2); + compute_encoder.set_bytes(conv_params, 2); MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); @@ -703,7 +681,7 @@ void winograd_conv_2D_gpu( compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(conv_params_updated, 2); + compute_encoder.set_bytes(conv_params, 2); MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); @@ -767,14 +745,18 @@ void conv_2D_gpu( } // Direct to winograd conv - bool inp_large = + bool img_large = (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; bool channels_large = (conv_params.C + conv_params.O) >= 256; - if (!flip && is_stride_one && is_kdil_one && is_idil_one && - conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && - conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && - channels_large) { - return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one && + is_kdil_one && is_idil_one) { + if (img_large && channels_large) { + return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + } + if (conv_params.N <= 1) { + return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies); + } } // Direct to implicit gemm conv @@ -876,8 +858,40 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { wt = arr_copy; } + // Check for 1x1 conv + auto is_one = [](int x) { return x == 1; }; + auto is_zero = [](int x) { return x == 0; }; + if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) && + std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) && + std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) && + std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) && + std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) && + std::all_of(padding_.begin(), padding_.end(), is_zero)) { + std::vector empty_copies; + steel_matmul_regular( + s, + d, + /*a = */ in, + /*b = */ wt, + /*c = */ out, + /*M = */ in.size() / in.shape(-1), + /*N = */ wt.shape(0), + /*K = */ in.shape(-1), + /*batch_size_out = */ 1, + /*lda = */ in.shape(-1), + /*ldb = */ wt.shape(-1), + /*ldd = */ wt.shape(0), + /*transpose_a = */ false, + /*transpose_b = */ true, + /*batch_shape = */ {1}, + /*batch_strides = */ {1}, + /*A_batch_stride = */ 0, + /*B_batch_stride = */ 0, + /*matrix_stride_out = */ 0, + /*copies = */ empty_copies); + } // 3D conv - if (out.ndim() == 5) { + else if (out.ndim() == 5) { conv_3D_gpu( s, d, diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 13ee239dc..925fa7f69 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -326,7 +326,13 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8]; constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8]; constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8]; -template +template < + typename T, + int BC = 32, + int BO = 4, + bool do_flip = false, + int M = 6, + int R = 3> [[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform( const device T* wt_in [[buffer(0)]], @@ -373,7 +379,12 @@ winograd_conv_2d_weight_transform( for (int kh = 0; kh < R; ++kh) { for (int kw = 0; kw < R; ++kw) { for (int kc = simd_lane_id; kc < BC; kc += 32) { - Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc]; + if (do_flip) { + Ws[simd_group_id][R - 1 - kh][R - 1 - kw][kc] = + wt_in[kh * R * C + kw * C + kc]; + } else { + Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc]; + } } } } @@ -398,10 +409,10 @@ winograd_conv_2d_weight_transform( } } -#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \ - template [[host_name("winograd_conv_2d_weight_transform_" #name \ - "_bc" #bc)]] [[kernel]] void \ - winograd_conv_2d_weight_transform( \ +#define instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, f) \ + template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc \ + "_flip" #f)]] [[kernel]] void \ + winograd_conv_2d_weight_transform( \ const device itype* wt_in [[buffer(0)]], \ device itype* wt_out [[buffer(1)]], \ const constant int& C [[buffer(2)]], \ @@ -410,6 +421,10 @@ winograd_conv_2d_weight_transform( uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]]); +#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \ + instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 0) \ + instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 1) + template [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void winograd_conv_2d_input_transform( @@ -445,10 +460,17 @@ winograd_conv_2d_input_transform( // Resolve input tile constexpr int TH = (A / WM); constexpr int TW = (A / WN); - int kh = TH * (simd_group_id / WN); - int kw = TW * (simd_group_id % WN); - int bh = M * tid.y + kh; - int bw = M * tid.x + kw; + const int kh = TH * (simd_group_id / WN); + const int kw = TW * (simd_group_id % WN); + const int bh = M * tid.y + kh - params.pad[1]; + const int bw = M * tid.x + kw - params.pad[0]; + + const bool is_edge_w_lo = bw < 0; + const bool is_edge_h_lo = bh < 0; + const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0]; + const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1]; + const bool is_edge = + is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi; // Move to the correct input tile inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] + @@ -484,8 +506,21 @@ winograd_conv_2d_input_transform( for (int h = 0; h < TH; h++) { for (int w = 0; w < TW; w++) { const device T* in_ptr = inp_in + jump_in[h][w]; - for (int c = simd_lane_id; c < BC; c += 32) { - Is[kh + h][kw + w][c] = in_ptr[c]; + if (is_edge) { + if (((bh + h) < 0 || (bh + h) >= params.iS[1]) || + ((bw + w) < 0 || (bw + w) >= params.iS[0])) { + for (int c = simd_lane_id; c < BC; c += 32) { + Is[kh + h][kw + w][c] = T(0); + } + } else { + for (int c = simd_lane_id; c < BC; c += 32) { + Is[kh + h][kw + w][c] = in_ptr[c]; + } + } + } else { + for (int c = simd_lane_id; c < BC; c += 32) { + Is[kh + h][kw + w][c] = in_ptr[c]; + } } } } @@ -652,3 +687,373 @@ winograd_conv_2d_output_transform( instantiate_winograd_conv_2d(float32, float); instantiate_winograd_conv_2d(bfloat16, bfloat16_t); instantiate_winograd_conv_2d(float16, half); // clang-format on + +#include "mlx/backend/metal/kernels/steel/attn/mma.h" + +template < + typename T, + bool do_flip = false, + int WM = 4, + int WN = 1, + typename AccumType = float> +[[kernel]] void winograd_fused( + const device T* input [[buffer(0)]], + const device T* weight [[buffer(1)]], + device T* output [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_per_grid [[threadgroups_per_grid]], + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + using namespace mlx::steel; + + (void)tgp_per_grid; + + // Winograd F(n x n, r x r) + // n x n output window + constexpr short FN = 2; + // r x r filter size + constexpr short FR = 3; + // a x a input window, a = n + r - 1 + constexpr short FA = 4; + + constexpr short kFragSize = 8; // MMA frag size + + constexpr short BT = 8; // Tile block size + constexpr short BO = 8; // Output channel block size + constexpr short BC = 8; // Input channel block size + + // clang-format off + static_assert(BT % (1 * kFragSize) == 0 && + BO % (1 * kFragSize) == 0 && + BC % kFragSize == 0, + "Matmuls sizes must be compatible with fragments"); + // clang-format on + + // Prepare for matmul + + // Warp tile sizes for matmul + constexpr short TM = (FA * FA * BT) / (WM * kFragSize); + constexpr short TN = (BO) / (WN * kFragSize); + constexpr short TK = (BC) / (kFragSize); + + // Warp primitives + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tiles sizes for matmul + MMATile Itile; + MMATile Wtile; + MMATile Otile[TM]; + + for (int im = 0; im < 4; im++) { + Otile[im].clear(); + } + + // Threadgroup memory for Weights and Inputs + constexpr short BS = BT > BO ? BT : BO; + threadgroup T Wt[FA * FA * BC * BO]; + threadgroup T It[FA * FA * BS * BS]; + + // Get thread position in tile + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + + static_assert(FA * FA * BT == 32 * WM * WN, "Each thread loads one pixel."); + const int thr_idx = simd_group_id * 32 + simd_lane_id; + const int thr_t = thr_idx / (FA * FA); + const int thr_hw = thr_idx % (FA * FA); + const int thr_h = thr_hw / FA; + const int thr_w = thr_hw % FA; + + // Get batch, tile, and output idx for warp + const int b_idx = tid.z; + const int t_idx = BT * tid.y + thr_t; + const int o_idx = BO * tid.x + thr_t; + + // Divide tile into h, w tile + uniform oHu = make_uniform(params.oS[0]); + uniform oWu = make_uniform(params.oS[1]); + uniform tHu = (oHu + make_uniform(FN - 1)) / make_uniform(FN); + uniform tWu = (oWu + make_uniform(FN - 1)) / make_uniform(FN); + + const int oH_idx = FN * (t_idx / tWu); + const int oW_idx = FN * (t_idx % tWu); + const int iH_idx = oH_idx + thr_h - params.pad[0]; + const int iW_idx = oW_idx + thr_w - params.pad[1]; + + // Move to correct location + + // clang-format off + input += b_idx * params.in_strides[0] + // N + iH_idx * params.in_strides[1] + // H + iW_idx * params.in_strides[2]; // W + + weight += o_idx * params.wt_strides[0] + // O + thr_h * params.wt_strides[1] + // H + thr_w * params.wt_strides[2]; // W + // clang-format on + + // Do edge check prep for input + const bool is_edge_w_lo = iH_idx < 0; + const bool is_edge_h_lo = iW_idx < 0; + const bool is_edge_w_hi = iH_idx >= params.iS[0]; + const bool is_edge_h_hi = iW_idx >= params.iS[1]; + const bool is_edge = + is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi; + + // Iterate over C + for (int c = 0; c < params.C; c += BC) { +#define tmp_load_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o +#define tmp_load_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c + +#define tmp_trns_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o +#define tmp_trns_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load weight + if (thr_h < FR && thr_w < FR && thr_t < BO) { + for (int ic = 0; ic < BC; ic++) { + if (do_flip) { + Wt[tmp_load_wt_idx(thr_t, FR - 1 - thr_h, FR - 1 - thr_w, ic)] = + weight[c + ic]; + } else { + Wt[tmp_load_wt_idx(thr_t, thr_h, thr_w, ic)] = weight[c + ic]; + } + } + } + + // Load input + if (is_edge) { + for (int ic = 0; ic < BC; ic++) { + It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = T(0); + } + } else { + for (int ic = 0; ic < BC; ic++) { + It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = input[c + ic]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Transform weight + if (lid.z == 0) { + const short ic = lid.y; + const short io = lid.x; + + T tmp_0[4][4]; + T tmp_1[4][4]; + + for (int ii = 0; ii < 3; ++ii) { + for (int jj = 0; jj < 3; ++jj) { + tmp_0[ii][jj] = Wt[tmp_load_wt_idx(io, ii, jj, ic)]; + } + } + + ////////////////////////////////////////////// + + tmp_1[0][0] = tmp_0[0][0]; + tmp_1[0][1] = tmp_0[0][1]; + tmp_1[0][2] = tmp_0[0][2]; + + tmp_1[1][0] = T(0.5) * (tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0]); + tmp_1[1][1] = T(0.5) * (tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1]); + tmp_1[1][2] = T(0.5) * (tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2]); + + tmp_1[2][0] = tmp_1[1][0] - tmp_0[1][0]; + tmp_1[2][1] = tmp_1[1][1] - tmp_0[1][1]; + tmp_1[2][2] = tmp_1[1][2] - tmp_0[1][2]; + + tmp_1[3][0] = tmp_0[2][0]; + tmp_1[3][1] = tmp_0[2][1]; + tmp_1[3][2] = tmp_0[2][2]; + + ////////////////////////////////////////////// + tmp_0[0][0] = tmp_1[0][0]; + tmp_0[1][0] = tmp_1[1][0]; + tmp_0[2][0] = tmp_1[2][0]; + tmp_0[3][0] = tmp_1[3][0]; + + tmp_0[0][1] = T(0.5) * (tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2]); + tmp_0[1][1] = T(0.5) * (tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2]); + tmp_0[2][1] = T(0.5) * (tmp_1[2][0] + tmp_1[2][1] + tmp_1[2][2]); + tmp_0[3][1] = T(0.5) * (tmp_1[3][0] + tmp_1[3][1] + tmp_1[3][2]); + + tmp_0[0][2] = tmp_0[0][1] - tmp_1[0][1]; + tmp_0[1][2] = tmp_0[1][1] - tmp_1[1][1]; + tmp_0[2][2] = tmp_0[2][1] - tmp_1[2][1]; + tmp_0[3][2] = tmp_0[3][1] - tmp_1[3][1]; + + tmp_0[0][3] = tmp_1[0][2]; + tmp_0[1][3] = tmp_1[1][2]; + tmp_0[2][3] = tmp_1[2][2]; + tmp_0[3][3] = tmp_1[3][2]; + + for (int ii = 0; ii < 4; ++ii) { + for (int jj = 0; jj < 4; ++jj) { + Wt[tmp_trns_wt_idx(io, ii, jj, ic)] = tmp_0[ii][jj]; + } + } + } + + // Transform input + else { + const short it = lid.y; + const short ic = lid.x; + + T tmp_0[4][4]; + T tmp_1[4][4]; + + for (int ii = 0; ii < 4; ++ii) { + for (int jj = 0; jj < 4; ++jj) { + tmp_0[ii][jj] = It[tmp_load_in_idx(it, ii, jj, ic)]; + } + } + + ////////////////////////////////////////////// + + tmp_1[0][0] = tmp_0[0][0] - tmp_0[2][0]; + tmp_1[0][1] = tmp_0[0][1] - tmp_0[2][1]; + tmp_1[0][2] = tmp_0[0][2] - tmp_0[2][2]; + tmp_1[0][3] = tmp_0[0][3] - tmp_0[2][3]; + + tmp_1[1][0] = tmp_0[1][0] + tmp_0[2][0]; + tmp_1[1][1] = tmp_0[1][1] + tmp_0[2][1]; + tmp_1[1][2] = tmp_0[1][2] + tmp_0[2][2]; + tmp_1[1][3] = tmp_0[1][3] + tmp_0[2][3]; + + tmp_1[2][0] = tmp_0[2][0] - tmp_0[1][0]; + tmp_1[2][1] = tmp_0[2][1] - tmp_0[1][1]; + tmp_1[2][2] = tmp_0[2][2] - tmp_0[1][2]; + tmp_1[2][3] = tmp_0[2][3] - tmp_0[1][3]; + + tmp_1[3][0] = tmp_0[1][0] - tmp_0[3][0]; + tmp_1[3][1] = tmp_0[1][1] - tmp_0[3][1]; + tmp_1[3][2] = tmp_0[1][2] - tmp_0[3][2]; + tmp_1[3][3] = tmp_0[1][3] - tmp_0[3][3]; + + ////////////////////////////////////////////// + tmp_0[0][0] = tmp_1[0][0] - tmp_1[0][2]; + tmp_0[1][0] = tmp_1[1][0] - tmp_1[1][2]; + tmp_0[2][0] = tmp_1[2][0] - tmp_1[2][2]; + tmp_0[3][0] = tmp_1[3][0] - tmp_1[3][2]; + + tmp_0[0][1] = tmp_1[0][1] + tmp_1[0][2]; + tmp_0[1][1] = tmp_1[1][1] + tmp_1[1][2]; + tmp_0[2][1] = tmp_1[2][1] + tmp_1[2][2]; + tmp_0[3][1] = tmp_1[3][1] + tmp_1[3][2]; + + tmp_0[0][2] = tmp_1[0][2] - tmp_1[0][1]; + tmp_0[1][2] = tmp_1[1][2] - tmp_1[1][1]; + tmp_0[2][2] = tmp_1[2][2] - tmp_1[2][1]; + tmp_0[3][2] = tmp_1[3][2] - tmp_1[3][1]; + + tmp_0[0][3] = tmp_1[0][1] - tmp_1[0][3]; + tmp_0[1][3] = tmp_1[1][1] - tmp_1[1][3]; + tmp_0[2][3] = tmp_1[2][1] - tmp_1[2][3]; + tmp_0[3][3] = tmp_1[3][1] - tmp_1[3][3]; + + for (int ii = 0; ii < 4; ++ii) { + for (int jj = 0; jj < 4; ++jj) { + It[tmp_trns_in_idx(it, ii, jj, ic)] = tmp_0[ii][jj]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + for (int im = 0; im < 4; im++) { + simdgroup_barrier(mem_flags::mem_none); + Itile.template load( + &It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]); + simdgroup_barrier(mem_flags::mem_none); + Wtile.template load( + &Wt[simd_group_id * FA * BC * BO + im * BC * BO + sm * BO + sn]); + simdgroup_barrier(mem_flags::mem_none); + tile_matmad(Otile[im], Itile, Wtile, Otile[im]); + } + } + + // Transform and write output + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int im = 0; im < 4; im++) { + Otile[im].template store( + &It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (lid.z == 0) { + const short it = lid.y; + const short io = lid.x; + + T tmp_0[4][4]; + T tmp_1[2][4]; + T tmp_2[2][2]; + + for (int ii = 0; ii < 4; ++ii) { + for (int jj = 0; jj < 4; ++jj) { + tmp_0[ii][jj] = It[tmp_trns_in_idx(it, ii, jj, io)]; + } + } + + tmp_1[0][0] = tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0]; + tmp_1[0][1] = tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1]; + tmp_1[0][2] = tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2]; + tmp_1[0][3] = tmp_0[0][3] + tmp_0[1][3] + tmp_0[2][3]; + + tmp_1[1][0] = tmp_0[1][0] - tmp_0[2][0] - tmp_0[3][0]; + tmp_1[1][1] = tmp_0[1][1] - tmp_0[2][1] - tmp_0[3][1]; + tmp_1[1][2] = tmp_0[1][2] - tmp_0[2][2] - tmp_0[3][2]; + tmp_1[1][3] = tmp_0[1][3] - tmp_0[2][3] - tmp_0[3][3]; + + tmp_2[0][0] = tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2]; + tmp_2[1][0] = tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2]; + + tmp_2[0][1] = tmp_1[0][1] - tmp_1[0][2] - tmp_1[0][3]; + tmp_2[1][1] = tmp_1[1][1] - tmp_1[1][2] - tmp_1[1][3]; + + const int oH_i = FN * ((BT * tid.y + it) / tWu); + const int oW_i = FN * ((BT * tid.y + it) % tWu); + + // clang-format off + output += b_idx * params.out_strides[0] + // N + oH_i * params.out_strides[1] + // H + oW_i * params.out_strides[2] + // W + BO * tid.x; // C + + // clang-format on + + output[0 * params.out_strides[1] + 0 * params.out_strides[2] + io] = + tmp_2[0][0]; + output[0 * params.out_strides[1] + 1 * params.out_strides[2] + io] = + tmp_2[0][1]; + output[1 * params.out_strides[1] + 0 * params.out_strides[2] + io] = + tmp_2[1][0]; + output[1 * params.out_strides[1] + 1 * params.out_strides[2] + io] = + tmp_2[1][1]; + } +} + +// clang-format off +#define instantiate_winograd_conv_2d_fused(name, itype, f) \ + template [[host_name("winograd_conv_2d_fused_" #name "_flip" #f)]] \ + [[kernel]] void winograd_fused( \ + const device itype* input [[buffer(0)]], \ + const device itype* weight [[buffer(1)]], \ + device itype* output [[buffer(2)]], \ + const constant MLXConvParams<2>& params [[buffer(3)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_per_grid [[threadgroups_per_grid]], \ + ushort simd_group_id [[simdgroup_index_in_threadgroup]], \ + ushort simd_lane_id [[thread_index_in_simdgroup]]); + +#define instantiate_winograd_conv_2d_fused_2(name, itype) \ + instantiate_winograd_conv_2d_fused(name, itype, 0) \ + instantiate_winograd_conv_2d_fused(name, itype, 1) + +instantiate_winograd_conv_2d_fused_2(float32, float); +instantiate_winograd_conv_2d_fused_2(float16, float16_t); +instantiate_winograd_conv_2d_fused_2(bfloat16, bfloat16_t); +// clang-format on \ No newline at end of file diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 514a244b0..22b940814 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3882,7 +3882,7 @@ array conv_general( return array( std::move(out_shape), - in.dtype(), + out_type, std::make_shared( to_stream(s), stride, diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 862e8ec7f..0064563e4 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -341,7 +341,7 @@ class TestConv(mlx_tests.MLXTestCase): atol, rtol = 1e-1, 1e-3 else: atol, rtol = 1e-5, 1e-6 - self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol, rtol=rtol)) for dtype in ("float32", "bfloat16"): for N, C, O in (