Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
fix conv2d bug + faster conv 1d (#2195)

* fix conv2d bug + faster conv 1d
* revert sort + flaky test
This commit is contained in:
parent
0654543dcc
commit
8576e6fe36
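The faster 1D path works by routing eligible shapes through the tuned 2D conv kernels (implicit GEMM, depthwise) via a 2D convolution with a singleton spatial dimension. Below is a minimal sketch of that equivalence at the Python API level, mirroring the test_1d_conv_with_2d test added in this diff; the shapes are illustrative only.

import mlx.core as mx

# (batch, length, channels) input and (out_channels, kernel, in_channels) weights,
# the layouts mx.conv1d expects.
x = mx.random.uniform(shape=(2, 10, 16))
w = mx.random.normal(shape=(16, 3, 16))

out_1d = mx.conv1d(x, w, padding=1)

# The same computation as a 2D convolution over a size-1 spatial axis, which is
# what conv_1D_gpu now builds internally (an MLXConvParams<2>) for eligible shapes.
out_2d = mx.conv2d(
    mx.expand_dims(x, axis=2),  # (2, 10, 1, 16)
    mx.expand_dims(w, axis=2),  # (16, 3, 1, 16)
    padding=(1, 0),
)

assert mx.allclose(out_1d, out_2d.squeeze(2)).item()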
@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>
#include <numeric>
@@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
      /*copies = */ copies);
}

void conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    int groups,
    bool flip) {
  // Make conv params
  MLXConvParams<1> conv_params{
      /* const int N = */ static_cast<int>(in.shape(0)),
      /* const int C = */ static_cast<int>(in.shape(2)),
      /* const int O = */ static_cast<int>(wt.shape(0)),
      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
      /* const int str[NDIM] = */ {wt_strides[0]},
      /* const int pad[NDIM] = */ {padding[0]},
      /* const int kdil[NDIM] = */ {wt_dilation[0]},
      /* const int idil[NDIM] = */ {in_dilation[0]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2]},
      /* const int groups = */ groups,
      /* const bool flip = */ flip};

  // Direct to explicit gemm conv
  if (groups > 1) {
    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
  } else {
    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
  }
}

void slow_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  int bm = 16, bn = 8;
  int tm = 4, tn = 4;

  std::ostringstream kname;
  kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
        << "_tm" << tm << "_tn" << tn;

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder.set_compute_pipeline_state(kernel);

  size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];

  size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
  size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
  size_t grid_dim_z = conv_params.N;

  MTL::Size group_dims = MTL::Size(bm, bn, 1);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);

  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_input_array(wt, 1);
  compute_encoder.set_output_array(out, 2);

  compute_encoder.set_bytes(conv_params, 3);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void implicit_gemm_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
@@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void dispatch_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params,
    std::vector<array>& copies) {
  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;

  if (is_idil_one && conv_params.groups > 1) {
    const int C_per_group = conv_params.C / conv_params.groups;
    const int O_per_group = conv_params.O / conv_params.groups;

    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
        conv_params.wt_strides[1] == conv_params.wS[1] &&
        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
    }

    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
        (O_per_group <= 16 || O_per_group % 16 == 0)) {
      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
    } else {
      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
    }
  }

  // Direct to winograd conv
  bool inp_large =
      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
  bool channels_large = (conv_params.C + conv_params.O) >= 256;
  if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
      channels_large) {
    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
  }

  // Direct to implicit gemm conv
  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }

  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to explicit gemm conv
  else {
    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
  }
}

void conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    int groups,
    bool flip,
    std::vector<array>& copies) {
  bool is_idil_one = in_dilation[0] == 1;
  int C = in.shape(2);
  int O = wt.shape(0);
  const int C_per_group = in.shape(2) / groups;
  const int O_per_group = wt.shape(0) / groups;

  // Direct to implicit gemm conv
  if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
      (O_per_group <= 16 || O_per_group % 16 == 0)) {
    MLXConvParams<2> conv_params{
        /* const int N = */ static_cast<int>(in.shape(0)),
        /* const int C = */ C,
        /* const int O = */ O,
        /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
        /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
        /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
        /* const int str[NDIM] = */ {wt_strides[0], 1},
        /* const int pad[NDIM] = */ {padding[0], 0},
        /* const int kdil[NDIM] = */ {wt_dilation[0], 1},
        /* const int idil[NDIM] = */ {in_dilation[0], 1},
        /* const size_t in_strides[NDIM + 2] = */
        {in.strides()[0], in.strides()[1], 0, in.strides()[2]},
        /* const size_t wt_strides[NDIM + 2] = */
        {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
        /* const size_t out_strides[NDIM + 2] = */
        {out.strides()[0], out.strides()[1], 0, out.strides()[2]},
        /* const int groups = */ groups,
        /* const bool flip = */ flip};

    dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
    return;
  }

  // Make conv params
  MLXConvParams<1> conv_params{
      /* const int N = */ static_cast<int>(in.shape(0)),
      /* const int C = */ static_cast<int>(in.shape(2)),
      /* const int O = */ static_cast<int>(wt.shape(0)),
      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
      /* const int str[NDIM] = */ {wt_strides[0]},
      /* const int pad[NDIM] = */ {padding[0]},
      /* const int kdil[NDIM] = */ {wt_dilation[0]},
      /* const int idil[NDIM] = */ {in_dilation[0]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2]},
      /* const int groups = */ groups,
      /* const bool flip = */ flip};

  // Direct to explicit gemm conv
  if (groups > 1) {
    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
  } else {
    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
  }
}

void conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
@@ -808,57 +865,7 @@ void conv_2D_gpu(
      /* const int groups = */ groups,
      /* const bool flip = */ flip,
  };

  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;

  if (is_idil_one && groups > 1) {
    const int C_per_group = conv_params.C / groups;
    const int O_per_group = conv_params.O / groups;

    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
        conv_params.wt_strides[1] == conv_params.wS[1] &&
        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
    }

    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
        (O_per_group <= 16 || O_per_group % 16 == 0)) {
      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
    } else {
      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
    }
  }

  // Direct to winograd conv
  bool inp_large =
      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
  bool channels_large = (conv_params.C + conv_params.O) >= 256;
  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
      channels_large) {
    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
  }

  // Direct to implicit gemm conv
  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }

  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to explicit gemm conv
  else {
    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
  }
  dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}

void conv_3D_gpu(
@@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
        kernel_dilation_,
        input_dilation_,
        groups_,
        flip_);
        flip_,
        copies);
  }
  // Throw error
  else {
@@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
  const constant MLXConvParams<2>* params;

  int weight_hw;
  int weight_step;

  const int read_n;
  const bool do_read;
@@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_hw(0),
        weight_step(params->C / params->groups),
        read_n(offsets.y + bi),
        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
@@ -435,13 +437,13 @@ struct Conv2DWeightBlockLoader {
  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_hw < (params->wS[1] * params->wS[0])) {
      src += params->wt_strides[2];
      src += weight_step;
      return;
    }

    weight_hw = 0;

    src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
    src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
  }
};
@@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
      return;
    }

    const device T* curr_src = src + weight_hw * params->wt_strides[2];
    const device T* curr_src = src + weight_hw * (params->C / params->groups);

    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
@@ -3584,21 +3584,21 @@ Shape conv_out_shape(

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << "for "
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
@@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase):
        )
        self.assertEqual(grads.shape, k_shape)

    def test_1d_conv_with_2d(self):
        x = mx.random.uniform(shape=(2, 10, 16))
        y = mx.random.normal(shape=(16, 3, 16))

        out = mx.conv1d(x, y, padding=1)
        out_2d = mx.conv2d(
            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
        )

        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))

        x = mx.random.uniform(shape=(2, 10, 4))
        y = mx.random.normal(shape=(4, 3, 4))

        out = mx.conv1d(x, y, padding=1)
        out_2d = mx.conv2d(
            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
        )

        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))


if __name__ == "__main__":
    unittest.main()
@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
        self.assertEqual(fy.shape, (4, 5, 6, 7))

    def test_leaks(self):
        mx.synchronize()
        if mx.metal.is_available():
            mem_pre = mx.get_active_memory()
        else: