diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 3e42f7d2f..1ece746f1 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -541,67 +541,6 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
-  Shape padded_shape = {
-      conv_params.N,
-      conv_params.iS[0] + 2 * conv_params.pad[0],
-      conv_params.iS[1] + 2 * conv_params.pad[1],
-      conv_params.C};
-
-  padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
-  padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
-
-  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
-
-  // Fill with zeros
-  array zero_arr = array(0, in.dtype());
-  fill_gpu(zero_arr, in_padded, s);
-  copies_w.push_back(zero_arr);
-
-  // Pick input slice from padded
-  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
-      conv_params.pad[1] * in_padded.strides()[2];
-  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
-  in_padded_slice.copy_shared_buffer(
-      in_padded,
-      in_padded.strides(),
-      in_padded.flags(),
-      in_padded_slice.size(),
-      data_offset);
-
-  // Copy input values into the slice
-  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
-
-  copies_w.push_back(in_padded_slice);
-  copies_w.push_back(in_padded);
-
-  MLXConvParams<2> conv_params_updated{
-      /* const int N = */ static_cast<int>(in_padded.shape(0)),
-      /* const int C = */ static_cast<int>(in_padded.shape(3)),
-      /* const int O = */ static_cast<int>(wt.shape(0)),
-      /* const int iS[NDIM] = */
-      {static_cast<int>(in_padded.shape(1)),
-       static_cast<int>(in_padded.shape(2))},
-      /* const int wS[NDIM] = */
-      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
-      /* const int oS[NDIM] = */
-      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
-      /* const int str[NDIM] = */ {1, 1},
-      /* const int pad[NDIM] = */ {0, 0},
-      /* const int kdil[NDIM] = */ {1, 1},
-      /* const int idil[NDIM] = */ {1, 1},
-      /* const size_t in_strides[NDIM + 2] = */
-      {in_padded.strides()[0],
-       in_padded.strides()[1],
-       in_padded.strides()[2],
-       in_padded.strides()[3]},
-      /* const size_t wt_strides[NDIM + 2] = */
-      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
-      /* const size_t out_strides[NDIM + 2] = */
-      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
-      /* const int groups = */ 1,
-      /* const bool flip = */ false,
-  };
-
   int O_c = conv_params.O;
   int C_c = conv_params.C;
 
@@ -653,10 +592,10 @@ void winograd_conv_2D_gpu(
     auto kernel = d.get_kernel(kname.str());
     compute_encoder.set_compute_pipeline_state(kernel);
 
-    compute_encoder.set_input_array(in_padded, 0);
+    compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(inp_wg, 1);
 
-    compute_encoder.set_bytes(conv_params_updated, 2);
+    compute_encoder.set_bytes(conv_params, 2);
 
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -703,7 +642,7 @@ void winograd_conv_2D_gpu(
     compute_encoder.set_input_array(out_wg, 0);
     compute_encoder.set_output_array(out, 1);
 
-    compute_encoder.set_bytes(conv_params_updated, 2);
+    compute_encoder.set_bytes(conv_params, 2);
 
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal
index 13ee239dc..06b62fcef 100644
--- a/mlx/backend/metal/kernels/conv.metal
+++ b/mlx/backend/metal/kernels/conv.metal
@@ -445,14 +445,23 @@ winograd_conv_2d_input_transform(
   // Resolve input tile
   constexpr int TH = (A / WM);
   constexpr int TW = (A / WN);
-  int kh = TH * (simd_group_id / WN);
-  int kw = TW * (simd_group_id % WN);
-  int bh = M * tid.y + kh;
-  int bw = M * tid.x + kw;
+  const int kh = TH * (simd_group_id / WN);
+  const int kw = TW * (simd_group_id % WN);
+  // bh/bw are coordinates along input dims 1/2 (they scale in_strides[1]/[2]),
+  // so they pair with pad[0]/iS[0] and pad[1]/iS[1] respectively.
+  const int bh = M * tid.y + kh - params.pad[0];
+  const int bw = M * tid.x + kw - params.pad[1];
+
+  const bool is_edge_w_lo = bw < 0;
+  const bool is_edge_h_lo = bh < 0;
+  const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[1];
+  const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[0];
+  const bool is_edge =
+      is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
 
   // Move to the correct input tile
-  inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
-      bw * params.in_strides[2];
+  inp_in += tid.z * params.in_strides[0] + bh * int64_t(params.in_strides[1]) +
+      bw * int64_t(params.in_strides[2]);
 
   // Pre compute strides
   int jump_in[TH][TW];
@@ -484,8 +493,22 @@ winograd_conv_2d_input_transform(
     for (int h = 0; h < TH; h++) {
       for (int w = 0; w < TW; w++) {
         const device T* in_ptr = inp_in + jump_in[h][w];
-        for (int c = simd_lane_id; c < BC; c += 32) {
-          Is[kh + h][kw + w][c] = in_ptr[c];
+        if (is_edge) {
+          // Taps outside the input act as implicit zero padding; never load.
+          if (((bh + h) < 0 || (bh + h) >= params.iS[0]) ||
+              ((bw + w) < 0 || (bw + w) >= params.iS[1])) {
+            for (int c = simd_lane_id; c < BC; c += 32) {
+              Is[kh + h][kw + w][c] = T(0);
+            }
+          } else {
+            for (int c = simd_lane_id; c < BC; c += 32) {
+              Is[kh + h][kw + w][c] = in_ptr[c];
+            }
+          }
+        } else {
+          for (int c = simd_lane_id; c < BC; c += 32) {
+            Is[kh + h][kw + w][c] = in_ptr[c];
+          }
         }
       }
     }