diff --git a/CMakeLists.txt b/CMakeLists.txt
index 17a832364..54d974a26 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.23.0)
+  set(MLX_VERSION 0.23.1)
 endif()

 add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 6356ad9ba..3e42f7d2f 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -533,45 +533,6 @@ void implicit_gemm_conv_2D_general_gpu(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-void winograd_conv_2D_fused_gpu(
-    const Stream& s,
-    metal::Device& d,
-    const array& in,
-    const array& wt,
-    array out,
-    const MLXConvParams<2>& conv_params,
-    std::vector<array>& copies_w) {
-  int O_c = conv_params.O;
-  int C_c = conv_params.C;
-
-  int N_tiles_n = conv_params.N;
-  int N_tiles_h = (conv_params.oS[0] + 1) / 2;
-  int N_tiles_w = (conv_params.oS[1] + 1) / 2;
-  int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
-
-  int bc = 32;
-  int wm = 4;
-  int wn = 1;
-  std::ostringstream kname;
-  kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
-        << conv_params.flip;
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
-  compute_encoder.set_compute_pipeline_state(kernel);
-
-  compute_encoder.set_input_array(in, 0);
-  compute_encoder.set_input_array(wt, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  compute_encoder.set_bytes(conv_params, 3);
-
-  MTL::Size group_dims = MTL::Size(8, 8, 2);
-  MTL::Size grid_dims =
-      MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n);
-
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
-}
-
 void winograd_conv_2D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -580,6 +541,67 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
+  Shape padded_shape = {
+      conv_params.N,
+      conv_params.iS[0] + 2 * conv_params.pad[0],
+      conv_params.iS[1] + 2 * conv_params.pad[1],
+      conv_params.C};
+
+  padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
+  padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
+
+  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
+
+  // Fill with zeros
+  array zero_arr = array(0, in.dtype());
+  fill_gpu(zero_arr, in_padded, s);
+  copies_w.push_back(zero_arr);
+
+  // Pick input slice from padded
+  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
+      conv_params.pad[1] * in_padded.strides()[2];
+  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
+  in_padded_slice.copy_shared_buffer(
+      in_padded,
+      in_padded.strides(),
+      in_padded.flags(),
+      in_padded_slice.size(),
+      data_offset);
+
+  // Copy input values into the slice
+  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
+
+  copies_w.push_back(in_padded_slice);
+  copies_w.push_back(in_padded);
+
+  MLXConvParams<2> conv_params_updated{
+      /* const int N = */ static_cast<int>(in_padded.shape(0)),
+      /* const int C = */ static_cast<int>(in_padded.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in_padded.shape(1)),
+       static_cast<int>(in_padded.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
+      /* const int str[NDIM] = */ {1, 1},
+      /* const int pad[NDIM] = */ {0, 0},
+      /* const int kdil[NDIM] = */ {1, 1},
+      /* const int idil[NDIM] = */ {1, 1},
+      /* const size_t in_strides[NDIM + 2] = */
+      {in_padded.strides()[0],
+       in_padded.strides()[1],
+       in_padded.strides()[2],
+       in_padded.strides()[3]},
+      /* const size_t wt_strides[NDIM + 2] = */
+      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
+      /* const size_t out_strides[NDIM + 2] = */
+      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
+      /* const int groups = */ 1,
+      /* const bool flip = */ false,
+  };
+
   int O_c = conv_params.O;
   int C_c = conv_params.C;

@@ -598,7 +620,7 @@ void winograd_conv_2D_gpu(
     int bo = 4;
     std::ostringstream kname;
     kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
-          << bc << "_flip" << conv_params.flip;
+          << bc;
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder.set_compute_pipeline_state(kernel);
@@ -631,10 +653,10 @@
     auto kernel = d.get_kernel(kname.str());
     compute_encoder.set_compute_pipeline_state(kernel);

-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(in_padded, 0);
     compute_encoder.set_output_array(inp_wg, 1);

-    compute_encoder.set_bytes(conv_params, 2);
+    compute_encoder.set_bytes(conv_params_updated, 2);

     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -681,7 +703,7 @@ void winograd_conv_2D_gpu(
     compute_encoder.set_input_array(out_wg, 0);
     compute_encoder.set_output_array(out, 1);

-    compute_encoder.set_bytes(conv_params, 2);
+    compute_encoder.set_bytes(conv_params_updated, 2);

     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -745,18 +767,14 @@ void conv_2D_gpu(
   }

   // Direct to winograd conv
-  bool img_large =
+  bool inp_large =
       (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
   bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
-      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
-      is_kdil_one && is_idil_one) {
-    if (img_large && channels_large) {
-      return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
-    }
-    if (conv_params.N <= 1) {
-      return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies);
-    }
+  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
+      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
+      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
+      channels_large) {
+    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
   }

   // Direct to implicit gemm conv
@@ -858,40 +876,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
     wt = arr_copy;
   }

-  // Check for 1x1 conv
-  auto is_one = [](int x) { return x == 1; };
-  auto is_zero = [](int x) { return x == 0; };
-  if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) &&
-      std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) &&
-      std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) &&
-      std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) &&
-      std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) &&
-      std::all_of(padding_.begin(), padding_.end(), is_zero)) {
-    std::vector<array> empty_copies;
-    steel_matmul_regular(
-        s,
-        d,
-        /*a = */ in,
-        /*b = */ wt,
-        /*c = */ out,
-        /*M = */ in.size() / in.shape(-1),
-        /*N = */ wt.shape(0),
-        /*K = */ in.shape(-1),
-        /*batch_size_out = */ 1,
-        /*lda = */ in.shape(-1),
-        /*ldb = */ wt.shape(-1),
-        /*ldd = */ wt.shape(0),
-        /*transpose_a = */ false,
-        /*transpose_b = */ true,
-        /*batch_shape = */ {1},
-        /*batch_strides = */ {1},
-        /*A_batch_stride = */ 0,
-        /*B_batch_stride = */ 0,
-        /*matrix_stride_out = */ 0,
-        /*copies = */ empty_copies);
-  }
   // 3D conv
-  else if (out.ndim() == 5) {
+  if (out.ndim() == 5) {
     conv_3D_gpu(
         s,
         d,
diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal
index be6fd6d1e..13ee239dc 100644
--- a/mlx/backend/metal/kernels/conv.metal
+++ b/mlx/backend/metal/kernels/conv.metal
@@ -326,13 +326,7 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
 constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
 constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];

-template <
-    typename T,
-    int BC = 32,
-    int BO = 4,
-    bool do_flip = false,
-    int M = 6,
-    int R = 3>
+template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
 [[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
 winograd_conv_2d_weight_transform(
     const device T* wt_in [[buffer(0)]],
@@ -379,12 +373,7 @@ winograd_conv_2d_weight_transform(
   for (int kh = 0; kh < R; ++kh) {
     for (int kw = 0; kw < R; ++kw) {
       for (int kc = simd_lane_id; kc < BC; kc += 32) {
-        if (do_flip) {
-          Ws[simd_group_id][R - 1 - kh][R - 1 - kw][kc] =
-              wt_in[kh * R * C + kw * C + kc];
-        } else {
-          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
-        }
+        Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
       }
     }
   }
@@ -409,10 +398,10 @@ winograd_conv_2d_weight_transform(
   }
 }

-#define instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, f)   \
-  template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc \
-                       "_flip" #f)]] [[kernel]] void                        \
-  winograd_conv_2d_weight_transform(                                        \
+#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
+  template [[host_name("winograd_conv_2d_weight_transform_" #name           \
+                       "_bc" #bc)]] [[kernel]] void                         \
+  winograd_conv_2d_weight_transform(                                        \
       const device itype* wt_in [[buffer(0)]],                              \
       device itype* wt_out [[buffer(1)]],                                   \
       const constant int& C [[buffer(2)]],                                  \
@@ -421,10 +410,6 @@ winograd_conv_2d_weight_transform(
       uint simd_group_id [[simdgroup_index_in_threadgroup]],                \
       uint simd_lane_id [[thread_index_in_simdgroup]]);

-#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
-  instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 0)         \
-  instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 1)
-
 template
 [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
 winograd_conv_2d_input_transform(
@@ -460,17 +445,10 @@ winograd_conv_2d_input_transform(
   // Resolve input tile
   constexpr int TH = (A / WM);
   constexpr int TW = (A / WN);
-  const int kh = TH * (simd_group_id / WN);
-  const int kw = TW * (simd_group_id % WN);
-  const int bh = M * tid.y + kh - params.pad[1];
-  const int bw = M * tid.x + kw - params.pad[0];
-
-  const bool is_edge_w_lo = bw < 0;
-  const bool is_edge_h_lo = bh < 0;
-  const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
-  const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
-  const bool is_edge =
-      is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
+  int kh = TH * (simd_group_id / WN);
+  int kw = TW * (simd_group_id % WN);
+  int bh = M * tid.y + kh;
+  int bw = M * tid.x + kw;

   // Move to the correct input tile
   inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
@@ -506,21 +484,8 @@ winograd_conv_2d_input_transform(
   for (int h = 0; h < TH; h++) {
     for (int w = 0; w < TW; w++) {
       const device T* in_ptr = inp_in + jump_in[h][w];
-      if (is_edge) {
-        if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
-            ((bw + w) < 0 || (bw + w) >= params.iS[0])) {
-          for (int c = simd_lane_id; c < BC; c += 32) {
-            Is[kh + h][kw + w][c] = T(0);
-          }
-        } else {
-          for (int c = simd_lane_id; c < BC; c += 32) {
-            Is[kh + h][kw + w][c] = in_ptr[c];
-          }
-        }
-      } else {
-        for (int c = simd_lane_id; c < BC; c += 32) {
-          Is[kh + h][kw + w][c] = in_ptr[c];
-        }
+      for (int c = simd_lane_id; c < BC; c += 32) {
+        Is[kh + h][kw + w][c] = in_ptr[c];
       }
     }
   }
@@ -687,371 +652,3 @@ winograd_conv_2d_output_transform(
 instantiate_winograd_conv_2d(float32, float);
 instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
 instantiate_winograd_conv_2d(float16, half); // clang-format on
-
-#include "mlx/backend/metal/kernels/steel/attn/mma.h"
-
-template <
-    typename T,
-    bool do_flip = false,
-    int WM = 4,
-    int WN = 1,
-    typename AccumType = float>
-[[kernel]] void winograd_fused(
-    const device T* input [[buffer(0)]],
-    const device T* weight [[buffer(1)]],
-    device T* output [[buffer(2)]],
-    const constant MLXConvParams<2>& params [[buffer(3)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint3 tgp_per_grid [[threadgroups_per_grid]],
-    ushort simd_group_id [[simdgroup_index_in_threadgroup]],
-    ushort simd_lane_id [[thread_index_in_simdgroup]]) {
-  using namespace mlx::steel;
-
-  (void)tgp_per_grid;
-
-  // Winograd F(n x n, r x r)
-  // n x n output window
-  constexpr short FN = 2;
-  // r x r filter size
-  constexpr short FR = 3;
-  // a x a input window, a = n + r - 1
-  constexpr short FA = 4;
-
-  constexpr short kFragSize = 8; // MMA frag size
-
-  constexpr short BT = 8; // Tile block size
-  constexpr short BO = 8; // Output channel block size
-  constexpr short BC = 8; // Input channel block size
-
-  // clang-format off
-  static_assert(BT % (1 * kFragSize) == 0 &&
-                BO % (1 * kFragSize) == 0 &&
-                BC % kFragSize == 0,
-      "Matmuls sizes must be compatible with fragments");
-  // clang-format on
-
-  // Prepare for matmul
-
-  // Warp tile sizes for matmul
-  constexpr short TM = (FA * FA * BT) / (WM * kFragSize);
-  constexpr short TN = (BO) / (WN * kFragSize);
-  constexpr short TK = (BC) / (kFragSize);
-
-  // Warp primitives
-  using MMAFrag_acc_t = BaseMMAFrag;
-
-  // Warp tiles sizes for matmul
-  MMATile Itile;
-  MMATile Wtile;
-  MMATile Otile[TM];
-
-  for (int im = 0; im < 4; im++) {
-    Otile[im].clear();
-  }
-
-  // Threadgroup memory for Weights and Inputs
-  constexpr short BS = BT > BO ? BT : BO;
-  threadgroup T Wt[FA * FA * BC * BO];
-  threadgroup T It[FA * FA * BS * BS];
-
-  // Get thread position in tile
-  short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
-  const short sm = simd_coord.y;
-  const short sn = simd_coord.x;
-
-  static_assert(FA * FA * BT == 32 * WM * WN, "Each thread loads one pixel.");
-  const int thr_idx = simd_group_id * 32 + simd_lane_id;
-  const int thr_t = thr_idx / (FA * FA);
-  const int thr_hw = thr_idx % (FA * FA);
-  const int thr_h = thr_hw / FA;
-  const int thr_w = thr_hw % FA;
-
-  // Get batch, tile, and output idx for warp
-  const int b_idx = tid.z;
-  const int t_idx = BT * tid.y + thr_t;
-  const int o_idx = BO * tid.x + thr_t;
-
-  // Divide tile into h, w tile
-  uniform oWu = make_uniform(params.oS[1]);
-  uniform tWu = (oWu + make_uniform(FN - 1)) / make_uniform(FN);
-
-  const int oH_idx = FN * (t_idx / tWu);
-  const int oW_idx = FN * (t_idx % tWu);
-  const int iH_idx = oH_idx + thr_h - params.pad[0];
-  const int iW_idx = oW_idx + thr_w - params.pad[1];
-
-  // Move to correct location
-
-  // clang-format off
-  input += b_idx * params.in_strides[0] +   // N
-           iH_idx * params.in_strides[1] +  // H
-           iW_idx * params.in_strides[2];   // W
-
-  weight += o_idx * params.wt_strides[0] +  // O
-            thr_h * params.wt_strides[1] +  // H
-            thr_w * params.wt_strides[2];   // W
-  // clang-format on
-
-  // Do edge check prep for input
-  const bool is_edge_w_lo = iH_idx < 0;
-  const bool is_edge_h_lo = iW_idx < 0;
-  const bool is_edge_w_hi = iH_idx >= params.iS[0];
-  const bool is_edge_h_hi = iW_idx >= params.iS[1];
-  const bool is_edge =
-      is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
-
-  // Iterate over C
-  for (int c = 0; c < params.C; c += BC) {
-#define tmp_load_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
-#define tmp_load_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
-
-#define tmp_trns_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
-#define tmp_trns_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load weight
-    if (thr_h < FR && thr_w < FR && thr_t < BO) {
-      for (int ic = 0; ic < BC; ic++) {
-        if (do_flip) {
-          Wt[tmp_load_wt_idx(thr_t, FR - 1 - thr_h, FR - 1 - thr_w, ic)] =
-              weight[c + ic];
-        } else {
-          Wt[tmp_load_wt_idx(thr_t, thr_h, thr_w, ic)] = weight[c + ic];
-        }
-      }
-    }
-
-    // Load input
-    if (is_edge) {
-      for (int ic = 0; ic < BC; ic++) {
-        It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = T(0);
-      }
-    } else {
-      for (int ic = 0; ic < BC; ic++) {
-        It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = input[c + ic];
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Transform weight
-    if (lid.z == 0) {
-      const short ic = lid.y;
-      const short io = lid.x;
-
-      T tmp_0[4][4];
-      T tmp_1[4][4];
-
-      for (int ii = 0; ii < 3; ++ii) {
-        for (int jj = 0; jj < 3; ++jj) {
-          tmp_0[ii][jj] = Wt[tmp_load_wt_idx(io, ii, jj, ic)];
-        }
-      }
-
-      //////////////////////////////////////////////
-
-      tmp_1[0][0] = tmp_0[0][0];
-      tmp_1[0][1] = tmp_0[0][1];
-      tmp_1[0][2] = tmp_0[0][2];
-
-      tmp_1[1][0] = T(0.5) * (tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0]);
-      tmp_1[1][1] = T(0.5) * (tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1]);
-      tmp_1[1][2] = T(0.5) * (tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2]);
-
-      tmp_1[2][0] = tmp_1[1][0] - tmp_0[1][0];
-      tmp_1[2][1] = tmp_1[1][1] - tmp_0[1][1];
-      tmp_1[2][2] = tmp_1[1][2] - tmp_0[1][2];
-
-      tmp_1[3][0] = tmp_0[2][0];
-      tmp_1[3][1] = tmp_0[2][1];
-      tmp_1[3][2] = tmp_0[2][2];
-
-      //////////////////////////////////////////////
-      tmp_0[0][0] = tmp_1[0][0];
-      tmp_0[1][0] = tmp_1[1][0];
-      tmp_0[2][0] = tmp_1[2][0];
-      tmp_0[3][0] = tmp_1[3][0];
-
-      tmp_0[0][1] = T(0.5) * (tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2]);
-      tmp_0[1][1] = T(0.5) * (tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2]);
-      tmp_0[2][1] = T(0.5) * (tmp_1[2][0] + tmp_1[2][1] + tmp_1[2][2]);
-      tmp_0[3][1] = T(0.5) * (tmp_1[3][0] + tmp_1[3][1] + tmp_1[3][2]);
-
-      tmp_0[0][2] = tmp_0[0][1] - tmp_1[0][1];
-      tmp_0[1][2] = tmp_0[1][1] - tmp_1[1][1];
-      tmp_0[2][2] = tmp_0[2][1] - tmp_1[2][1];
-      tmp_0[3][2] = tmp_0[3][1] - tmp_1[3][1];
-
-      tmp_0[0][3] = tmp_1[0][2];
-      tmp_0[1][3] = tmp_1[1][2];
-      tmp_0[2][3] = tmp_1[2][2];
-      tmp_0[3][3] = tmp_1[3][2];
-
-      for (int ii = 0; ii < 4; ++ii) {
-        for (int jj = 0; jj < 4; ++jj) {
-          Wt[tmp_trns_wt_idx(io, ii, jj, ic)] = tmp_0[ii][jj];
-        }
-      }
-    }
-
-    // Transform input
-    else {
-      const short it = lid.y;
-      const short ic = lid.x;
-
-      T tmp_0[4][4];
-      T tmp_1[4][4];
-
-      for (int ii = 0; ii < 4; ++ii) {
-        for (int jj = 0; jj < 4; ++jj) {
-          tmp_0[ii][jj] = It[tmp_load_in_idx(it, ii, jj, ic)];
-        }
-      }
-
-      //////////////////////////////////////////////
-
-      tmp_1[0][0] = tmp_0[0][0] - tmp_0[2][0];
-      tmp_1[0][1] = tmp_0[0][1] - tmp_0[2][1];
-      tmp_1[0][2] = tmp_0[0][2] - tmp_0[2][2];
-      tmp_1[0][3] = tmp_0[0][3] - tmp_0[2][3];
-
-      tmp_1[1][0] = tmp_0[1][0] + tmp_0[2][0];
-      tmp_1[1][1] = tmp_0[1][1] + tmp_0[2][1];
-      tmp_1[1][2] = tmp_0[1][2] + tmp_0[2][2];
-      tmp_1[1][3] = tmp_0[1][3] + tmp_0[2][3];
-
-      tmp_1[2][0] = tmp_0[2][0] - tmp_0[1][0];
-      tmp_1[2][1] = tmp_0[2][1] - tmp_0[1][1];
-      tmp_1[2][2] = tmp_0[2][2] - tmp_0[1][2];
-      tmp_1[2][3] = tmp_0[2][3] - tmp_0[1][3];
-
-      tmp_1[3][0] = tmp_0[1][0] - tmp_0[3][0];
-      tmp_1[3][1] = tmp_0[1][1] - tmp_0[3][1];
-      tmp_1[3][2] = tmp_0[1][2] - tmp_0[3][2];
-      tmp_1[3][3] = tmp_0[1][3] - tmp_0[3][3];
-
-      //////////////////////////////////////////////
-      tmp_0[0][0] = tmp_1[0][0] - tmp_1[0][2];
-      tmp_0[1][0] = tmp_1[1][0] - tmp_1[1][2];
-      tmp_0[2][0] = tmp_1[2][0] - tmp_1[2][2];
-      tmp_0[3][0] = tmp_1[3][0] - tmp_1[3][2];
-
-      tmp_0[0][1] = tmp_1[0][1] + tmp_1[0][2];
-      tmp_0[1][1] = tmp_1[1][1] + tmp_1[1][2];
-      tmp_0[2][1] = tmp_1[2][1] + tmp_1[2][2];
-      tmp_0[3][1] = tmp_1[3][1] + tmp_1[3][2];
-
-      tmp_0[0][2] = tmp_1[0][2] - tmp_1[0][1];
-      tmp_0[1][2] = tmp_1[1][2] - tmp_1[1][1];
-      tmp_0[2][2] = tmp_1[2][2] - tmp_1[2][1];
-      tmp_0[3][2] = tmp_1[3][2] - tmp_1[3][1];
-
-      tmp_0[0][3] = tmp_1[0][1] - tmp_1[0][3];
-      tmp_0[1][3] = tmp_1[1][1] - tmp_1[1][3];
-      tmp_0[2][3] = tmp_1[2][1] - tmp_1[2][3];
-      tmp_0[3][3] = tmp_1[3][1] - tmp_1[3][3];
-
-      for (int ii = 0; ii < 4; ++ii) {
-        for (int jj = 0; jj < 4; ++jj) {
-          It[tmp_trns_in_idx(it, ii, jj, ic)] = tmp_0[ii][jj];
-        }
-      }
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Do matmul
-    for (int im = 0; im < 4; im++) {
-      simdgroup_barrier(mem_flags::mem_none);
-      Itile.template load(
-          &It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
-      simdgroup_barrier(mem_flags::mem_none);
-      Wtile.template load(
-          &Wt[simd_group_id * FA * BC * BO + im * BC * BO + sm * BO + sn]);
-      simdgroup_barrier(mem_flags::mem_none);
-      tile_matmad(Otile[im], Itile, Wtile, Otile[im]);
-    }
-  }
-
-  // Transform and write output
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-  for (int im = 0; im < 4; im++) {
-    Otile[im].template store(
-        &It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  if (lid.z == 0) {
-    const short it = lid.y;
-    const short io = lid.x;
-
-    T tmp_0[4][4];
-    T tmp_1[2][4];
-    T tmp_2[2][2];
-
-    for (int ii = 0; ii < 4; ++ii) {
-      for (int jj = 0; jj < 4; ++jj) {
-        tmp_0[ii][jj] = It[tmp_trns_in_idx(it, ii, jj, io)];
-      }
-    }
-
-    tmp_1[0][0] = tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0];
-    tmp_1[0][1] = tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1];
-    tmp_1[0][2] = tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2];
-    tmp_1[0][3] = tmp_0[0][3] + tmp_0[1][3] + tmp_0[2][3];
-
-    tmp_1[1][0] = tmp_0[1][0] - tmp_0[2][0] - tmp_0[3][0];
-    tmp_1[1][1] = tmp_0[1][1] - tmp_0[2][1] - tmp_0[3][1];
-    tmp_1[1][2] = tmp_0[1][2] - tmp_0[2][2] - tmp_0[3][2];
-    tmp_1[1][3] = tmp_0[1][3] - tmp_0[2][3] - tmp_0[3][3];
-
-    tmp_2[0][0] = tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2];
-    tmp_2[1][0] = tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2];
-
-    tmp_2[0][1] = tmp_1[0][1] - tmp_1[0][2] - tmp_1[0][3];
-    tmp_2[1][1] = tmp_1[1][1] - tmp_1[1][2] - tmp_1[1][3];
-
-    const int oH_i = FN * ((BT * tid.y + it) / tWu);
-    const int oW_i = FN * ((BT * tid.y + it) % tWu);
-
-    // clang-format off
-    output += b_idx * params.out_strides[0] +  // N
-              oH_i * params.out_strides[1] +   // H
-              oW_i * params.out_strides[2] +   // W
-              BO * tid.x;                      // C
-
-    // clang-format on
-
-    output[0 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
-        tmp_2[0][0];
-    output[0 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
-        tmp_2[0][1];
-    output[1 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
-        tmp_2[1][0];
-    output[1 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
-        tmp_2[1][1];
-  }
-}
-
-// clang-format off
-#define instantiate_winograd_conv_2d_fused(name, itype, f)           \
-  template [[host_name("winograd_conv_2d_fused_" #name "_flip" #f)]] \
-  [[kernel]] void winograd_fused(                                     \
-      const device itype* input [[buffer(0)]],                        \
-      const device itype* weight [[buffer(1)]],                       \
-      device itype* output [[buffer(2)]],                             \
-      const constant MLXConvParams<2>& params [[buffer(3)]],          \
-      uint3 tid [[threadgroup_position_in_grid]],                     \
-      uint3 lid [[thread_position_in_threadgroup]],                   \
-      uint3 tgp_per_grid [[threadgroups_per_grid]],                   \
-      ushort simd_group_id [[simdgroup_index_in_threadgroup]],        \
-      ushort simd_lane_id [[thread_index_in_simdgroup]]);
-
-#define instantiate_winograd_conv_2d_fused_2(name, itype) \
-  instantiate_winograd_conv_2d_fused(name, itype, 0)      \
-  instantiate_winograd_conv_2d_fused(name, itype, 1)
-
-instantiate_winograd_conv_2d_fused_2(float32, float);
-instantiate_winograd_conv_2d_fused_2(float16, float16_t);
-instantiate_winograd_conv_2d_fused_2(bfloat16, bfloat16_t);
-// clang-format on
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 22b940814..514a244b0 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3882,7 +3882,7 @@ array conv_general(

   return array(
       std::move(out_shape),
-      out_type,
+      in.dtype(),
       std::make_shared<Convolution>(
           to_stream(s),
           stride,
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index 0064563e4..9dd8fd140 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -341,7 +341,7 @@ class TestConv(mlx_tests.MLXTestCase):
                 atol, rtol = 1e-1, 1e-3
             else:
                 atol, rtol = 1e-5, 1e-6
-            self.assertTrue(np.allclose(out_pt, out_mx, atol=atol, rtol=rtol))
+            self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

         for dtype in ("float32", "bfloat16"):
             for N, C, O in (
@@ -1042,6 +1042,14 @@ class TestConv(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(expected[0], grads[0]))
         self.assertTrue(mx.allclose(expected[1], grads[1]))

+    def test_repeated_conv(self):
+        x = mx.random.normal((1, 3, 3, 320))
+        w = mx.random.normal((320, 3, 3, 320))
+        for i in range(8):
+            y1 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
+            y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
+            self.assertTrue(mx.allclose(y1, y2))
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/setup.py b/setup.py
index 85f3ae3c0..9296254ff 100644
--- a/setup.py
+++ b/setup.py
@@ -173,7 +173,7 @@ if __name__ == "__main__":
     setup(
         name="mlx",
-        version=get_version("0.23.0"),
+        version=get_version("0.23.1"),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",