From 8576e6fe3606bf5b805162fd5f4a7803a9a0d349 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 18 May 2025 06:05:11 -0700
Subject: [PATCH] fix conv2d bug + faster conv 1d (#2195)

* fix conv2d bug + faster conv 1d

* revert sort + flaky test
---
 mlx/backend/metal/conv.cpp                    | 268 +++++++++---------
 .../steel/conv/loaders/loader_channel_l.h     |   8 +-
 .../steel/conv/loaders/loader_channel_n.h     |   4 +-
 mlx/ops.cpp                                   |   6 +-
 python/tests/test_conv.py                     |  21 ++
 python/tests/test_vmap.py                     |   1 +
 6 files changed, 170 insertions(+), 138 deletions(-)

diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 35ed3d44e..6b4b70d47 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 #include
 #include
 #include
@@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
       /*copies = */ copies);
 }
 
-void conv_1D_gpu(
-    const Stream& s,
-    metal::Device& d,
-    const array& in,
-    const array& wt,
-    array out,
-    const std::vector<int>& padding,
-    const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation,
-    const std::vector<int>& in_dilation,
-    int groups,
-    bool flip) {
-  // Make conv params
-  MLXConvParams<1> conv_params{
-      /* const int N = */ static_cast<int>(in.shape(0)),
-      /* const int C = */ static_cast<int>(in.shape(2)),
-      /* const int O = */ static_cast<int>(wt.shape(0)),
-      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
-      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
-      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
-      /* const int str[NDIM] = */ {wt_strides[0]},
-      /* const int pad[NDIM] = */ {padding[0]},
-      /* const int kdil[NDIM] = */ {wt_dilation[0]},
-      /* const int idil[NDIM] = */ {in_dilation[0]},
-      /* const size_t in_strides[NDIM + 2] = */
-      {in.strides()[0], in.strides()[1], in.strides()[2]},
-      /* const size_t wt_strides[NDIM + 2] = */
-      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
-      /* const size_t out_strides[NDIM + 2] = */
-      {out.strides()[0], out.strides()[1], out.strides()[2]},
-      /* const int groups = */ groups,
-      /* const bool flip = */ flip};
-
-  // Direct to explicit gemm conv
-  if (groups > 1) {
-    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
-  } else {
-    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
-  }
-}
-
-void slow_conv_2D_gpu(
-    const Stream& s,
-    metal::Device& d,
-    const array& in,
-    const array& wt,
-    array out,
-    const MLXConvParams<2>& conv_params) {
-  int bm = 16, bn = 8;
-  int tm = 4, tn = 4;
-
-  std::ostringstream kname;
-  kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
-        << "_tm" << tm << "_tn" << tn;
-
-  // Encode and dispatch kernel
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
-  compute_encoder.set_compute_pipeline_state(kernel);
-
-  size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
-
-  size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
-  size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
-  size_t grid_dim_z = conv_params.N;
-
-  MTL::Size group_dims = MTL::Size(bm, bn, 1);
-  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
-
-  compute_encoder.set_input_array(in, 0);
-  compute_encoder.set_input_array(wt, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  compute_encoder.set_bytes(conv_params, 3);
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
-}
-
 void implicit_gemm_conv_2D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void dispatch_conv_2D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    const array& wt,
+    array out,
+    const MLXConvParams<2>& conv_params,
+    std::vector<array>& copies) {
+  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
+  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
+  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
+
+  if (is_idil_one && conv_params.groups > 1) {
+    const int C_per_group = conv_params.C / conv_params.groups;
+    const int O_per_group = conv_params.O / conv_params.groups;
+
+    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+        conv_params.wt_strides[1] == conv_params.wS[1] &&
+        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    }
+
+    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
+        (O_per_group <= 16 || O_per_group % 16 == 0)) {
+      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    } else {
+      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+    }
+  }
+
+  // Direct to winograd conv
+  bool inp_large =
+      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
+  bool channels_large = (conv_params.C + conv_params.O) >= 256;
+  if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
+      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
+      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
+      channels_large) {
+    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+  }
+
+  // Direct to implicit gemm conv
+  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
+      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
+    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+  }
+
+  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
+    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
+  }
+
+  // Direct to explicit gemm conv
+  else {
+    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+  }
+}
+
+void conv_1D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    const array& wt,
+    array out,
+    const std::vector<int>& padding,
+    const std::vector<int>& wt_strides,
+    const std::vector<int>& wt_dilation,
+    const std::vector<int>& in_dilation,
+    int groups,
+    bool flip,
+    std::vector<array>& copies) {
+  bool is_idil_one = in_dilation[0] == 1;
+  int C = in.shape(2);
+  int O = wt.shape(0);
+  const int C_per_group = in.shape(2) / groups;
+  const int O_per_group = wt.shape(0) / groups;
+
+  // Direct to implicit gemm conv
+  if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+      (O_per_group <= 16 || O_per_group % 16 == 0)) {
+    MLXConvParams<2> conv_params{
+        /* const int N = */ static_cast<int>(in.shape(0)),
+        /* const int C = */ C,
+        /* const int O = */ O,
+        /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
+        /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
+        /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
+        /* const int str[NDIM] = */ {wt_strides[0], 1},
+        /* const int pad[NDIM] = */ {padding[0], 0},
+        /* const int kdil[NDIM] = */ {wt_dilation[0], 1},
+        /* const int idil[NDIM] = */ {in_dilation[0], 1},
+        /* const size_t in_strides[NDIM + 2] = */
+        {in.strides()[0], in.strides()[1], 0, in.strides()[2]},
+        /* const size_t wt_strides[NDIM + 2] = */
+        {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
+        /* const size_t out_strides[NDIM + 2] = */
+        {out.strides()[0], out.strides()[1], 0, out.strides()[2]},
+        /* const int groups = */ groups,
+        /* const bool flip = */ flip};
+
+    dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+    return;
+  }
+
+  // Make conv params
+  MLXConvParams<1> conv_params{
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(2)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
+      /* const int str[NDIM] = */ {wt_strides[0]},
+      /* const int pad[NDIM] = */ {padding[0]},
+      /* const int kdil[NDIM] = */ {wt_dilation[0]},
+      /* const int idil[NDIM] = */ {in_dilation[0]},
+      /* const size_t in_strides[NDIM + 2] = */
+      {in.strides()[0], in.strides()[1], in.strides()[2]},
+      /* const size_t wt_strides[NDIM + 2] = */
+      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
+      /* const size_t out_strides[NDIM + 2] = */
+      {out.strides()[0], out.strides()[1], out.strides()[2]},
+      /* const int groups = */ groups,
+      /* const bool flip = */ flip};
+
+  // Direct to explicit gemm conv
+  if (groups > 1) {
+    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+  } else {
+    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+  }
+}
+
 void conv_2D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -808,57 +865,7 @@ void conv_2D_gpu(
       /* const int groups = */ groups,
       /* const bool flip = */ flip,
   };
-
-  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
-  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
-  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
-
-  if (is_idil_one && groups > 1) {
-    const int C_per_group = conv_params.C / groups;
-    const int O_per_group = conv_params.O / groups;
-
-    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
-        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
-        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
-        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
-        conv_params.wt_strides[1] == conv_params.wS[1] &&
-        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
-      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
-    }
-
-    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
-        (O_per_group <= 16 || O_per_group % 16 == 0)) {
-      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-    } else {
-      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
-    }
-  }
-
-  // Direct to winograd conv
-  bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
-  bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
-      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
-      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
-      channels_large) {
-    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
-  }
-
-  // Direct to implicit gemm conv
-  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
-      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
-    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
-    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  // Direct to explicit gemm conv
-  else {
-    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
-  }
+  dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
 }
 
 void conv_3D_gpu(
@@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         kernel_dilation_,
         input_dilation_,
         groups_,
-        flip_);
+        flip_,
+        copies);
   }
   // Throw error
   else {
diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
index dad496e81..d52642b73 100644
--- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
+++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
@@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
   const constant MLXConvParams<2>* params;
 
   int weight_hw;
+  int weight_step;
 
   const int read_n;
   const bool do_read;
@@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
         src(src_ + bi * src_ld + bj),
         params(params_),
         weight_hw(0),
+        weight_step(params->C / params->groups),
         read_n(offsets.y + bi),
         do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
 
@@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader {
   /* Iteration helper */
   METAL_FUNC void next() {
     if (++weight_hw < (params->wS[1] * params->wS[0])) {
-      src += params->wt_strides[2];
+      src += weight_step;
       return;
     }
 
    weight_hw = 0;
 
-    src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
+    src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
  }
 };
 
 } // namespace steel
-} // namespace mlx
\ No newline at end of file
+} // namespace mlx
diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
index 56027916e..b0b98d21a 100644
--- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
+++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
@@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
       return;
     }
 
-    const device T* curr_src = src + weight_hw * params->wt_strides[2];
+    const device T* curr_src = src + weight_hw * (params->C / params->groups);
 
     if (BN != 8 || do_read) {
       STEEL_PRAGMA_UNROLL
@@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels {
 };
 
 } // namespace steel
-} // namespace mlx
\ No newline at end of file
+} // namespace mlx
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 0c18cccfe..a72c2bc85 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3584,21 +3584,21 @@ Shape conv_out_shape(
   if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
     std::ostringstream msg;
-    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
+    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
         << spatial_dims << "D convolution.";
     throw std::invalid_argument(msg.str());
   }
 
   if (kernel_dilation.size() != spatial_dims) {
     std::ostringstream msg;
-    msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
+    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
     throw std::invalid_argument(msg.str());
   }
 
   if (input_dilation.size() != spatial_dims) {
     std::ostringstream msg;
-    msg << "[conv] Invalid input dilation " << input_dilation << "for "
+    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
     throw std::invalid_argument(msg.str());
   }
 
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index 7d63e4751..9fe11286d 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase):
         )
         self.assertEqual(grads.shape, k_shape)
 
+    def test_1d_conv_with_2d(self):
+        x = mx.random.uniform(shape=(2, 10, 16))
+        y = mx.random.normal(shape=(16, 3, 16))
+
+        out = mx.conv1d(x, y, padding=1)
+        out_2d = mx.conv2d(
+            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
+        )
+
+        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
+
+        x = mx.random.uniform(shape=(2, 10, 4))
+        y = mx.random.normal(shape=(4, 3, 4))
+
+        out = mx.conv1d(x, y, padding=1)
+        out_2d = mx.conv2d(
+            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
+        )
+
+        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index e571678d3..ddfceb0a1 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
         self.assertEqual(fy.shape, (4, 5, 6, 7))
 
     def test_leaks(self):
+        mx.synchronize()
         if mx.metal.is_available():
             mem_pre = mx.get_active_memory()
         else:
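
Note on the conv1d fast path: conv_1D_gpu now reshapes eligible 1D convolutions into 2D convolutions with a singleton second spatial axis and reuses the 2D dispatch above. The added test_1d_conv_with_2d checks exactly this equivalence; a minimal standalone sketch of the same check, using only the public mx.conv1d / mx.conv2d API that appears in the test, is:

import mlx.core as mx

# A 1D convolution equals a 2D convolution over a height-1 image:
# insert a singleton spatial axis in both the input and the weights.
x = mx.random.uniform(shape=(2, 10, 16))  # (N, L, C_in), channels-last
w = mx.random.normal(shape=(16, 3, 16))   # (C_out, K, C_in)

out_1d = mx.conv1d(x, w, padding=1)
out_2d = mx.conv2d(
    mx.expand_dims(x, axis=2),  # (N, L, 1, C_in)
    mx.expand_dims(w, axis=2),  # (C_out, K, 1, C_in)
    padding=(1, 0),             # no padding on the singleton axis
)

assert mx.allclose(out_1d, out_2d.squeeze(2))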
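
For readers skimming the conv.cpp refactor: dispatch_conv_2D_gpu only hoists the kernel-selection logic that previously lived inline in conv_2D_gpu, so the new conv_1D_gpu path can reuse it. A rough Python restatement of the routing conditions follows; the function name and argument list are illustrative only (the real decision is made in C++ on MLXConvParams<2>, and the weight-stride contiguity check in the depthwise branch is omitted here):

def choose_conv2d_kernel(N, C, O, iS, wS, oS, stride, kdil, idil, groups, flip):
    # Mirrors the branch order in dispatch_conv_2D_gpu above (illustrative).
    stride_one = stride == (1, 1)
    kdil_one = kdil == (1, 1)
    idil_one = idil == (1, 1)

    if idil_one and groups > 1:
        c_pg, o_pg = C // groups, O // groups
        if (c_pg == 1 and o_pg == 1 and kdil_one and max(wS) <= 7
                and max(stride) <= 2 and oS[0] % 8 == 0 and oS[1] % 8 == 0
                and C % 16 == 0 and C == O):
            return "depthwise"
        if (c_pg <= 4 or c_pg % 16 == 0) and (o_pg <= 16 or o_pg % 16 == 0):
            return "implicit_gemm"
        return "explicit_gemm_grouped"

    large_input = N * iS[0] * iS[1] >= 1 << 12
    large_channels = C + O >= 256
    if (not flip and stride_one and kdil_one and idil_one and wS == (3, 3)
            and C % 32 == 0 and O % 32 == 0 and large_input and large_channels):
        return "winograd"

    if idil_one and (C <= 4 or C % 16 == 0) and (O <= 16 or O % 16 == 0):
        return "implicit_gemm"
    if C % 16 == 0 and O % 16 == 0:
        return "implicit_gemm_general"
    return "explicit_gemm"

# e.g. choose_conv2d_kernel(4, 128, 128, (32, 32), (3, 3), (32, 32),
#                           (1, 1), (1, 1), (1, 1), 1, False) -> "winograd"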