Remove unnecessary copy from winograd

This commit is contained in:
Jagrit Digani 2025-01-06 14:06:03 -08:00
parent 058d6ce683
commit f14b4d72de
2 changed files with 31 additions and 72 deletions

View File

@ -541,67 +541,6 @@ void winograd_conv_2D_gpu(
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
Shape padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
fill_gpu(zero_arr, in_padded, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{
/* const int N = */ static_cast<int>(in_padded.shape(0)),
/* const int C = */ static_cast<int>(in_padded.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in_padded.shape(1)),
static_cast<int>(in_padded.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
/* const int idil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ false,
};
int O_c = conv_params.O;
int C_c = conv_params.C;
@ -653,10 +592,10 @@ void winograd_conv_2D_gpu(
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder.set_bytes(conv_params_updated, 2);
compute_encoder.set_bytes(conv_params, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@ -703,7 +642,7 @@ void winograd_conv_2D_gpu(
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(conv_params_updated, 2);
compute_encoder.set_bytes(conv_params, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

View File

@ -445,14 +445,21 @@ winograd_conv_2d_input_transform(
// Resolve input tile
constexpr int TH = (A / WM);
constexpr int TW = (A / WN);
int kh = TH * (simd_group_id / WN);
int kw = TW * (simd_group_id % WN);
int bh = M * tid.y + kh;
int bw = M * tid.x + kw;
const int kh = TH * (simd_group_id / WN);
const int kw = TW * (simd_group_id % WN);
const int bh = M * tid.y + kh - params.pad[1];
const int bw = M * tid.x + kw - params.pad[0];
const bool is_edge_w_lo = bw < 0;
const bool is_edge_h_lo = bh < 0;
const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
const bool is_edge =
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
// Move to the correct input tile
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
bw * params.in_strides[2];
inp_in += tid.z * params.in_strides[0] + bh * int64_t(params.in_strides[1]) +
bw * int64_t(params.in_strides[2]);
// Pre compute strides
int jump_in[TH][TW];
@ -484,8 +491,21 @@ winograd_conv_2d_input_transform(
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
if (is_edge) {
if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
((bw + w) < 0 || (bw + w) >= params.iS[0])) {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = T(0);
}
} else {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
} else {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
}
}