mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Remove unnecessary copy from winograd
This commit is contained in:
parent
058d6ce683
commit
f14b4d72de
@ -541,67 +541,6 @@ void winograd_conv_2D_gpu(
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params,
|
||||
std::vector<array>& copies_w) {
|
||||
Shape padded_shape = {
|
||||
conv_params.N,
|
||||
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||
conv_params.C};
|
||||
|
||||
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
|
||||
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
|
||||
|
||||
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
array zero_arr = array(0, in.dtype());
|
||||
fill_gpu(zero_arr, in_padded, s);
|
||||
copies_w.push_back(zero_arr);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
||||
conv_params.pad[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||
|
||||
copies_w.push_back(in_padded_slice);
|
||||
copies_w.push_back(in_padded);
|
||||
|
||||
MLXConvParams<2> conv_params_updated{
|
||||
/* const int N = */ static_cast<int>(in_padded.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in_padded.shape(3)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */
|
||||
{static_cast<int>(in_padded.shape(1)),
|
||||
static_cast<int>(in_padded.shape(2))},
|
||||
/* const int wS[NDIM] = */
|
||||
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
|
||||
/* const int oS[NDIM] = */
|
||||
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
|
||||
/* const int str[NDIM] = */ {1, 1},
|
||||
/* const int pad[NDIM] = */ {0, 0},
|
||||
/* const int kdil[NDIM] = */ {1, 1},
|
||||
/* const int idil[NDIM] = */ {1, 1},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in_padded.strides()[0],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2],
|
||||
in_padded.strides()[3]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||
/* const int groups = */ 1,
|
||||
/* const bool flip = */ false,
|
||||
};
|
||||
|
||||
int O_c = conv_params.O;
|
||||
int C_c = conv_params.C;
|
||||
|
||||
@ -653,10 +592,10 @@ void winograd_conv_2D_gpu(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
@ -703,7 +642,7 @@ void winograd_conv_2D_gpu(
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
@ -445,14 +445,21 @@ winograd_conv_2d_input_transform(
|
||||
// Resolve input tile
|
||||
constexpr int TH = (A / WM);
|
||||
constexpr int TW = (A / WN);
|
||||
int kh = TH * (simd_group_id / WN);
|
||||
int kw = TW * (simd_group_id % WN);
|
||||
int bh = M * tid.y + kh;
|
||||
int bw = M * tid.x + kw;
|
||||
const int kh = TH * (simd_group_id / WN);
|
||||
const int kw = TW * (simd_group_id % WN);
|
||||
const int bh = M * tid.y + kh - params.pad[1];
|
||||
const int bw = M * tid.x + kw - params.pad[0];
|
||||
|
||||
const bool is_edge_w_lo = bw < 0;
|
||||
const bool is_edge_h_lo = bh < 0;
|
||||
const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
|
||||
const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
|
||||
const bool is_edge =
|
||||
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
|
||||
|
||||
// Move to the correct input tile
|
||||
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
|
||||
bw * params.in_strides[2];
|
||||
inp_in += tid.z * params.in_strides[0] + bh * int64_t(params.in_strides[1]) +
|
||||
bw * int64_t(params.in_strides[2]);
|
||||
|
||||
// Pre compute strides
|
||||
int jump_in[TH][TW];
|
||||
@ -484,8 +491,21 @@ winograd_conv_2d_input_transform(
|
||||
for (int h = 0; h < TH; h++) {
|
||||
for (int w = 0; w < TW; w++) {
|
||||
const device T* in_ptr = inp_in + jump_in[h][w];
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
if (is_edge) {
|
||||
if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
|
||||
((bw + w) < 0 || (bw + w) >= params.iS[0])) {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = T(0);
|
||||
}
|
||||
} else {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user