fix conv2d bug + faster conv 1d

This commit is contained in:
Awni Hannun 2025-05-16 12:12:20 -07:00
parent 48ef3e74e2
commit 3b169acf50
6 changed files with 181 additions and 167 deletions

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
/*copies = */ copies);
}
// Runs a 1D convolution on the GPU through the explicit GEMM routines.
//
// Builds an MLXConvParams<1> from the shapes and strides of `in`, `wt` and
// `out`, plus element 0 of `padding`, `wt_strides`, `wt_dilation` and
// `in_dilation`. It then dispatches on `groups`: the grouped explicit GEMM
// routine when `groups > 1`, the ungrouped one otherwise.
void conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    int groups,
    bool flip) {
  // Scalar extents, named after the MLXConvParams fields they populate.
  const int batch = static_cast<int>(in.shape(0));
  const int in_channels = static_cast<int>(in.shape(2));
  const int out_channels = static_cast<int>(wt.shape(0));
  const int in_len = static_cast<int>(in.shape(1));
  const int kernel_len = static_cast<int>(wt.shape(1));
  const int out_len = static_cast<int>(out.shape(1));

  // Field order below must match the MLXConvParams declaration.
  MLXConvParams<1> conv_params{
      /* N = */ batch,
      /* C = */ in_channels,
      /* O = */ out_channels,
      /* iS = */ {in_len},
      /* wS = */ {kernel_len},
      /* oS = */ {out_len},
      /* str = */ {wt_strides[0]},
      /* pad = */ {padding[0]},
      /* kdil = */ {wt_dilation[0]},
      /* idil = */ {in_dilation[0]},
      /* in_strides = */
      {in.strides()[0], in.strides()[1], in.strides()[2]},
      /* wt_strides = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
      /* out_strides = */
      {out.strides()[0], out.strides()[1], out.strides()[2]},
      /* groups = */ groups,
      /* flip = */ flip};

  // Ungrouped case first; everything else takes the grouped explicit GEMM.
  if (groups <= 1) {
    explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
    return;
  }
  explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
void slow_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 16, bn = 8;
int tm = 4, tn = 4;
std::ostringstream kname;
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
size_t grid_dim_z = conv_params.N;
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Picks a 2D convolution implementation for `conv_params` and runs it.
//
// Selection order:
//   1. Grouped convolutions with unit input dilation: the depthwise kernel
//      when its constraints hold, otherwise implicit GEMM when the per-group
//      channel counts fit, otherwise grouped explicit GEMM.
//   2. Winograd for unflipped 3x3 kernels with unit stride and dilation,
//      channel counts that are multiples of 32, and large enough inputs.
//   3. Implicit GEMM, then general implicit GEMM, then explicit GEMM.
//
// `copies` is forwarded only to the Winograd path.
void dispatch_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params,
    std::vector<array>& copies) {
  const auto& p = conv_params;

  const bool unit_stride = p.str[0] == 1 && p.str[1] == 1;
  const bool unit_kernel_dilation = p.kdil[0] == 1 && p.kdil[1] == 1;
  const bool unit_input_dilation = p.idil[0] == 1 && p.idil[1] == 1;

  // Channel-count requirement shared by both implicit GEMM decisions below.
  const auto fits_implicit_gemm = [](int c, int o) {
    const bool c_ok = c <= 4 || c % 16 == 0;
    const bool o_ok = o <= 16 || o % 16 == 0;
    return c_ok && o_ok;
  };

  // Grouped convolutions
  if (unit_input_dilation && p.groups > 1) {
    const int group_C = p.C / p.groups;
    const int group_O = p.O / p.groups;

    const bool one_channel_per_group = group_C == 1 && group_O == 1;
    const bool small_kernel = p.wS[0] <= 7 && p.wS[1] <= 7;
    const bool small_stride = p.str[0] <= 2 && p.str[1] <= 2;
    const bool out_multiple_of_8 = p.oS[0] % 8 == 0 && p.oS[1] % 8 == 0;
    const bool weight_stride_matches = p.wt_strides[1] == p.wS[1];
    const bool channels_match = p.C % 16 == 0 && p.C == p.O;

    if (one_channel_per_group && unit_kernel_dilation && small_kernel &&
        small_stride && out_multiple_of_8 && weight_stride_matches &&
        channels_match) {
      depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
      return;
    }

    if (fits_implicit_gemm(group_C, group_O)) {
      implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
    } else {
      explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
    }
    return;
  }

  // Direct to winograd conv
  const bool inp_large = (p.N * p.iS[0] * p.iS[1]) >= (1ul << 12);
  const bool channels_large = (p.C + p.O) >= 256;
  const bool kernel_is_3x3 = p.wS[0] == 3 && p.wS[1] == 3;
  const bool channels_multiple_of_32 = p.C % 32 == 0 && p.O % 32 == 0;
  if (!p.flip && unit_stride && unit_kernel_dilation && unit_input_dilation &&
      kernel_is_3x3 && channels_multiple_of_32 && inp_large &&
      channels_large) {
    winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
    return;
  }

  // Direct to implicit gemm conv
  if (unit_input_dilation && fits_implicit_gemm(p.C, p.O)) {
    implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
    return;
  }

  // General implicit gemm conv (does not require unit input dilation)
  if (p.C % 16 == 0 && p.O % 16 == 0) {
    implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
    return;
  }

  // Direct to explicit gemm conv
  explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip,
std::vector<array>& copies) {
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
MLXConvParams<2> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ C,
/* const int O = */ O,
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
/* const int str[NDIM] = */ {wt_strides[0], 1},
/* const int pad[NDIM] = */ {padding[0], 0},
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
/* const int idil[NDIM] = */ {in_dilation[0], 1},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
return;
}
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@ -808,57 +865,7 @@ void conv_2D_gpu(
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
void conv_3D_gpu(
@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_dilation_,
input_dilation_,
groups_,
flip_);
flip_,
copies);
}
// Throw error
else {

View File

@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
const constant MLXConvParams<2>* params;
int weight_hw;
int weight_step;
const int read_n;
const bool do_read;
@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
weight_step(params->C / params->groups),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader {
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
src += params->wt_strides[2];
src += weight_step;
return;
}
weight_hw = 0;
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
}
};
} // namespace steel
} // namespace mlx
} // namespace mlx

View File

@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
const device T* curr_src = src + weight_hw * (params->C / params->groups);
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels {
};
} // namespace steel
} // namespace mlx
} // namespace mlx

View File

@ -21,6 +21,8 @@ void single_block_sort(
int bn,
int tn,
bool argsort) {
out.set_data(allocator::malloc(out.nbytes()));
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@ -156,9 +158,6 @@ void multi_block_sort(
dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
std::vector<array> copies = {
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
@ -250,25 +249,17 @@ void multi_block_sort(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
// Copy outputs with appropriate strides
auto strides = out.strides();
for (int ax = axis + 1; ax < strides.size(); ax++) {
strides[ax] *= out.shape(axis);
}
strides[axis] = 1;
copy_gpu_inplace(
(argsort) ? dev_idxs_out : dev_vals_out,
out,
out.shape(),
strides,
out.copy_shared_buffer(
argsort ? dev_idxs_out : dev_vals_out,
out.strides(),
0,
0,
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
s);
d.add_temporaries(std::move(copies), s.index);
out.flags(),
out.data_size());
d.add_temporaries(
{dev_vals_in,
dev_idxs_in,
argsort ? dev_vals_in : dev_idxs_in,
block_partitions},
s.index);
}
void gpu_merge_sort(
@ -318,8 +309,6 @@ void gpu_merge_sort(
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
@ -330,8 +319,6 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
@ -343,8 +330,6 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct arg partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
@ -356,8 +341,6 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];

View File

@ -3584,21 +3584,21 @@ Shape conv_out_shape(
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (kernel_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (input_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid input dilation " << input_dilation << "for "
msg << "[conv] Invalid input dilation " << input_dilation << " for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}

View File

@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase):
)
self.assertEqual(grads.shape, k_shape)
def test_1d_conv_with_2d(self):
    """conv1d must match conv2d run on the same data with an added unit axis."""

    def check(in_shape, wt_shape):
        # Random input/weights, then compare the 1D result against the 2D
        # result computed with an extra spatial axis of size 1 (axis=2).
        x = mx.random.uniform(shape=in_shape)
        y = mx.random.normal(shape=wt_shape)
        out = mx.conv1d(x, y, padding=1)
        x_2d = mx.expand_dims(x, axis=2)
        y_2d = mx.expand_dims(y, axis=2)
        out_2d = mx.conv2d(x_2d, y_2d, padding=(1, 0))
        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))

    # 16 input / 16 output channels.
    check((2, 10, 16), (16, 3, 16))
    # 4 input / 4 output channels.
    check((2, 10, 4), (4, 3, 4))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()