Remove unnecessary copy from winograd

parent 058d6ce683
commit f14b4d72de

@@ -541,67 +541,6 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
-  Shape padded_shape = {
-      conv_params.N,
-      conv_params.iS[0] + 2 * conv_params.pad[0],
-      conv_params.iS[1] + 2 * conv_params.pad[1],
-      conv_params.C};
-
-  padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
-  padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
-
-  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
-
-  // Fill with zeros
-  array zero_arr = array(0, in.dtype());
-  fill_gpu(zero_arr, in_padded, s);
-  copies_w.push_back(zero_arr);
-
-  // Pick input slice from padded
-  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
-      conv_params.pad[1] * in_padded.strides()[2];
-  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
-  in_padded_slice.copy_shared_buffer(
-      in_padded,
-      in_padded.strides(),
-      in_padded.flags(),
-      in_padded_slice.size(),
-      data_offset);
-
-  // Copy input values into the slice
-  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
-
-  copies_w.push_back(in_padded_slice);
-  copies_w.push_back(in_padded);
-
-  MLXConvParams<2> conv_params_updated{
-      /* const int N = */ static_cast<int>(in_padded.shape(0)),
-      /* const int C = */ static_cast<int>(in_padded.shape(3)),
-      /* const int O = */ static_cast<int>(wt.shape(0)),
-      /* const int iS[NDIM] = */
-      {static_cast<int>(in_padded.shape(1)),
-       static_cast<int>(in_padded.shape(2))},
-      /* const int wS[NDIM] = */
-      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
-      /* const int oS[NDIM] = */
-      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
-      /* const int str[NDIM] = */ {1, 1},
-      /* const int pad[NDIM] = */ {0, 0},
-      /* const int kdil[NDIM] = */ {1, 1},
-      /* const int idil[NDIM] = */ {1, 1},
-      /* const size_t in_strides[NDIM + 2] = */
-      {in_padded.strides()[0],
-       in_padded.strides()[1],
-       in_padded.strides()[2],
-       in_padded.strides()[3]},
-      /* const size_t wt_strides[NDIM + 2] = */
-      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
-      /* const size_t out_strides[NDIM + 2] = */
-      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
-      /* const int groups = */ 1,
-      /* const bool flip = */ false,
-  };
-
   int O_c = conv_params.O;
   int C_c = conv_params.C;
 
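
The block removed above implemented explicit padding on the host: it allocated a zero-filled buffer whose spatial extents were rounded up to whole Winograd tiles (the 6 and 2 in the rounding expression presumably correspond to the 6-wide output tile and the 2-pixel overlap between neighbouring 8-wide input tiles of the F(6x6, 3x3) transform), copied the input into an offset slice of that buffer, and built a second MLXConvParams describing the now stride-1, pad-0 problem. After this commit none of that intermediate buffer exists; the input-transform kernel reads the original input and applies the padding on the fly, as the kernel hunks below show. A minimal standalone sketch of the rounding arithmetic, with an illustrative function name and example value that are not from the repository:

// Standalone sketch (not MLX code) of the rounding the removed host path did:
// round a padded extent up to a whole number of 6-wide Winograd output tiles
// plus a 2-pixel halo.
#include <cstdio>

int round_to_winograd_tiles(int padded_extent) {
  // Same expression as the removed lines:
  //   6 * ((x - 2 + 5) / 6) + 2  ==  6 * ceil((x - 2) / 6) + 2
  return 6 * ((padded_extent - 2 + 5) / 6) + 2;
}

int main() {
  // e.g. a padded extent of 33 is rounded up to 38 = 6 * 6 + 2
  std::printf("%d\n", round_to_winograd_tiles(33));
  return 0;
}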

@@ -653,10 +592,10 @@ void winograd_conv_2D_gpu(
     auto kernel = d.get_kernel(kname.str());
     compute_encoder.set_compute_pipeline_state(kernel);
 
-    compute_encoder.set_input_array(in_padded, 0);
+    compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(inp_wg, 1);
 
-    compute_encoder.set_bytes(conv_params_updated, 2);
+    compute_encoder.set_bytes(conv_params, 2);
 
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

@@ -703,7 +642,7 @@ void winograd_conv_2D_gpu(
     compute_encoder.set_input_array(out_wg, 0);
     compute_encoder.set_output_array(out, 1);
 
-    compute_encoder.set_bytes(conv_params_updated, 2);
+    compute_encoder.set_bytes(conv_params, 2);
 
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

@@ -445,14 +445,21 @@ winograd_conv_2d_input_transform(
   // Resolve input tile
   constexpr int TH = (A / WM);
   constexpr int TW = (A / WN);
-  int kh = TH * (simd_group_id / WN);
-  int kw = TW * (simd_group_id % WN);
-  int bh = M * tid.y + kh;
-  int bw = M * tid.x + kw;
+  const int kh = TH * (simd_group_id / WN);
+  const int kw = TW * (simd_group_id % WN);
+  const int bh = M * tid.y + kh - params.pad[1];
+  const int bw = M * tid.x + kw - params.pad[0];
 
+  const bool is_edge_w_lo = bw < 0;
+  const bool is_edge_h_lo = bh < 0;
+  const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
+  const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
+  const bool is_edge =
+      is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
+
   // Move to the correct input tile
-  inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
-      bw * params.in_strides[2];
+  inp_in += tid.z * params.in_strides[0] + bh * int64_t(params.in_strides[1]) +
+      bw * int64_t(params.in_strides[2]);
 
   // Pre compute strides
   int jump_in[TH][TW];
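
On the kernel side, the input-transform tile origin (bh, bw) is now shifted back by the padding, so it can be negative or run past the image, and a set of is_edge_* flags records whether the whole tile lies inside the unpadded input; the stride products are also cast to int64_t, presumably so that the now possibly negative bh and bw multiply correctly against the unsigned strides. A rough standalone C++ sketch of the tile classification, with struct and function names that are illustrative rather than taken from the kernel:

#include <cstdio>

struct TileEdge {
  bool w_lo, h_lo, w_hi, h_hi;
  bool any() const { return w_lo || h_lo || w_hi || h_hi; }
};

// Mirrors the added flags: a tile needs the checked load path if its origin
// was shifted into the padding (negative) or its far corner passes the image.
TileEdge classify_tile(int bh, int bw, int TH, int TW, int iS0, int iS1) {
  return TileEdge{bw < 0, bh < 0, bw + (TW - 1) >= iS0, bh + (TH - 1) >= iS1};
}

int main() {
  // A 2x2 tile whose origin sits one pixel inside the top-left padding.
  TileEdge e = classify_tile(/*bh=*/-1, /*bw=*/-1, 2, 2, /*iS0=*/8, /*iS1=*/8);
  std::printf("%d\n", e.any()); // 1: this tile must bounds-check its reads
  return 0;
}

Interior tiles, which are the common case, keep the original unchecked load; only border tiles pay for per-element bounds checks, as the next hunk shows.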

@@ -484,10 +491,23 @@ winograd_conv_2d_input_transform(
   for (int h = 0; h < TH; h++) {
     for (int w = 0; w < TW; w++) {
       const device T* in_ptr = inp_in + jump_in[h][w];
+      if (is_edge) {
+        if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
+            ((bw + w) < 0 || (bw + w) >= params.iS[0])) {
+          for (int c = simd_lane_id; c < BC; c += 32) {
+            Is[kh + h][kw + w][c] = T(0);
+          }
+        } else {
           for (int c = simd_lane_id; c < BC; c += 32) {
             Is[kh + h][kw + w][c] = in_ptr[c];
           }
         }
+      } else {
+        for (int c = simd_lane_id; c < BC; c += 32) {
+          Is[kh + h][kw + w][c] = in_ptr[c];
+        }
+      }
+    }
   }
 
   threadgroup_barrier(mem_flags::mem_threadgroup);
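
In the element loop, lanes whose (bh + h, bw + w) coordinates fall outside the image are zero-filled in threadgroup memory and all other lanes copy straight from the unpadded input, reproducing what the explicit zero-fill plus copy on the host used to materialize. A small standalone sketch of that implicit-padding read, with illustrative names not taken from the kernel:

// Standalone sketch (not the Metal kernel): coordinates in the padding region
// read as zero, everything else reads from the original, unpadded input.
#include <cstdio>
#include <vector>

float read_with_implicit_pad(
    const std::vector<float>& in, int H, int W, int h, int w) {
  if (h < 0 || h >= H || w < 0 || w >= W) {
    return 0.0f; // would have been a zero in the padded copy
  }
  return in[h * W + w];
}

int main() {
  std::vector<float> in = {1, 2, 3, 4}; // 2x2 input, row-major
  // (-1, 0) lies in the top padding row; (1, 1) is a real element.
  std::printf(
      "%g %g\n",
      read_with_implicit_pad(in, 2, 2, -1, 0),
      read_with_implicit_pad(in, 2, 2, 1, 1));
  return 0;
}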