fix conv2d bug + faster conv 1d (#2195)

* fix conv2d bug + faster conv 1d

* revert sort + flaky test
This commit is contained in:
Awni Hannun 2025-05-18 06:05:11 -07:00 committed by GitHub
parent 0654543dcc
commit 8576e6fe36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 170 additions and 138 deletions

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void slow_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 16, bn = 8;
int tm = 4, tn = 4;
std::ostringstream kname;
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
size_t grid_dim_z = conv_params.N;
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void dispatch_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies) {
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && conv_params.groups > 1) {
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip,
std::vector<array>& copies) {
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
MLXConvParams<2> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ C,
/* const int O = */ O,
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
/* const int str[NDIM] = */ {wt_strides[0], 1},
/* const int pad[NDIM] = */ {padding[0], 0},
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
/* const int idil[NDIM] = */ {in_dilation[0], 1},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
return;
}
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@ -808,57 +865,7 @@ void conv_2D_gpu(
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
void conv_3D_gpu(
@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_dilation_,
input_dilation_,
groups_,
flip_);
flip_,
copies);
}
// Throw error
else {

View File

@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
const constant MLXConvParams<2>* params;
int weight_hw;
int weight_step;
const int read_n;
const bool do_read;
@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
weight_step(params->C / params->groups),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
@ -435,13 +437,13 @@ struct Conv2DWeightBlockLoader {
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
src += params->wt_strides[2];
src += weight_step;
return;
}
weight_hw = 0;
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
}
};

View File

@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
const device T* curr_src = src + weight_hw * (params->C / params->groups);
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL

View File

@ -3584,21 +3584,21 @@ Shape conv_out_shape(
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (kernel_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (input_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid input dilation " << input_dilation << "for "
msg << "[conv] Invalid input dilation " << input_dilation << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}

View File

@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase):
)
self.assertEqual(grads.shape, k_shape)
def test_1d_conv_with_2d(self):
x = mx.random.uniform(shape=(2, 10, 16))
y = mx.random.normal(shape=(16, 3, 16))
out = mx.conv1d(x, y, padding=1)
out_2d = mx.conv2d(
mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
)
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
x = mx.random.uniform(shape=(2, 10, 4))
y = mx.random.normal(shape=(4, 3, 4))
out = mx.conv1d(x, y, padding=1)
out_2d = mx.conv2d(
mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
)
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
if __name__ == "__main__":
unittest.main()

View File

@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(fy.shape, (4, 5, 6, 7))
def test_leaks(self):
mx.synchronize()
if mx.metal.is_available():
mem_pre = mx.get_active_memory()
else: