From 849fee90f329eb01508b8997ad46c8547647f93b Mon Sep 17 00:00:00 2001
From: Cheng
Date: Sun, 17 Aug 2025 17:23:23 -0700
Subject: [PATCH] Add gemm_grouped_conv

---
 mlx/backend/cuda/CMakeLists.txt            |   1 +
 mlx/backend/cuda/conv/conv.h               |  64 +++++-
 mlx/backend/cuda/conv/gemm_conv.cu         |  27 +--
 mlx/backend/cuda/conv/gemm_grouped_conv.cu | 231 +++++++++++++++++++++
 mlx/backend/cuda/gemms/cublas_gemm.cpp     |  19 ++
 mlx/backend/cuda/gemms/cublas_gemm.h       |  11 +
 python/tests/cuda_skip.py                  |   2 -
 7 files changed, 330 insertions(+), 25 deletions(-)
 create mode 100644 mlx/backend/cuda/conv/gemm_grouped_conv.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index eff16c98b..994307284 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -17,6 +17,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
diff --git a/mlx/backend/cuda/conv/conv.h b/mlx/backend/cuda/conv/conv.h
index 7fd0a37ae..62dc9343e 100644
--- a/mlx/backend/cuda/conv/conv.h
+++ b/mlx/backend/cuda/conv/conv.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/gpu/copy.h"
 
 namespace mlx::core {
 
@@ -48,7 +49,32 @@ struct ConvParams {
   }
 };
 
+void gemm_grouped_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    int groups,
+    bool flip,
+    Stream s);
+
 void gemm_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    bool flip,
+    Stream s);
+
+inline void gemm_conv(
     cu::CommandEncoder& encoder,
     array in,
     array wt,
@@ -59,6 +85,42 @@ void gemm_conv(
     const std::vector<int>& input_dilation,
     int groups,
     bool flip,
-    Stream s);
+    Stream s) {
+  if (!in.flags().row_contiguous) {
+    in = contiguous_copy_gpu(in, s);
+    encoder.add_temporary(in);
+  }
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    encoder.add_temporary(wt);
+  }
+
+  if (groups == 1) {
+    gemm_conv(
+        encoder,
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        flip,
+        s);
+  } else {
+    gemm_grouped_conv(
+        encoder,
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        groups,
+        flip,
+        s);
+  }
+}
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu
index 3aae9dae5..11a78a7ab 100644
--- a/mlx/backend/cuda/conv/gemm_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_conv.cu
@@ -3,7 +3,6 @@
 #include "mlx/backend/cuda/conv/conv.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 
 #include <fmt/format.h>
@@ -137,21 +136,16 @@ void gemm_conv_nd(
     array& out,
     ConvParams<NDIM>& params,
     Stream s) {
-  if (params.groups > 1) {
-    throw std::runtime_error(
-        "[conv] gemm_conv does not support grouped convolution yet.");
-  }
-
   // Get gemm shapes.
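+  // im2col lowering: every output position becomes one GEMM row of length
+  // C * H_wt * W_wt, so the whole convolution reduces to a single
+  // (mat_M x mat_K) x (mat_K x mat_N) matmul.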
   int mat_M = out.size() / params.O; // N * H_out * W_out
   int mat_K = wt.size() / params.O; // C * H_wt * W_wt
   int mat_N = params.O; // O
 
-  // Unfold input for gemm.
+  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
   array in_unfolded =
       unfold_inputs_nd(encoder, in, mat_M, mat_K, mat_N, params);
 
-  // Reshape weight for gemm.
+  // Reshape weight to (C * H_wt * W_wt, O) for gemm.
   array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
   wt_reshaped.copy_shared_buffer(
       wt,
@@ -191,14 +185,13 @@
 
 void gemm_conv(
     cu::CommandEncoder& encoder,
-    array in,
-    array wt,
+    const array& in,
+    const array& wt,
     array& out,
     const std::vector<int>& strides,
     const std::vector<int>& padding,
     const std::vector<int>& kernel_dilation,
     const std::vector<int>& input_dilation,
-    int groups,
     bool flip,
     Stream s) {
   int conv_ndim = in.ndim() - 2;
   if (conv_ndim < 1 || conv_ndim > 3) {
     throw std::runtime_error(
         fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
   }
-
-  if (!in.flags().row_contiguous) {
-    in = contiguous_copy_gpu(in, s);
-    encoder.add_temporary(in);
-  }
-  if (!wt.flags().row_contiguous) {
-    wt = contiguous_copy_gpu(wt, s);
-    encoder.add_temporary(wt);
-  }
-
   dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
     ConvParams<ndim_constant()> params(
         in,
         wt,
         out,
         strides,
         padding,
         kernel_dilation,
         input_dilation,
-        groups,
+        1, // groups
         flip);
     gemm_conv_nd(encoder, in, wt, out, params, s);
   });
diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
new file mode 100644
index 000000000..7ceb58166
--- /dev/null
+++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
@@ -0,0 +1,231 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/conv/conv.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+
+#include <fmt/format.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, int NDIM>
+__global__ void naive_grouped_unfold_transpose_nd(
+    const T* in,
+    T* out,
+    int filter_size,
+    int out_pixels,
+    const __grid_constant__ ConvParams<NDIM> params) {
+  auto block = cg::this_thread_block();
+  auto tid = block.group_index();
+  auto lid = block.thread_index();
+
+  int index_batch = tid.z / out_pixels; // [0, N)
+  int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
+  int index_wt_spatial =
+      tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
+
+  if (index_wt_spatial >= filter_size / params.C) {
+    return;
+  }
+
+  in += tid.y; // [0, C)
+  out += tid.z * filter_size + tid.y * (filter_size / params.C);
+
+  bool valid = index_batch < params.N;
+
+  // Get the coordinates in input.
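+  // Walk the spatial dims innermost-first: peel off one output and one
+  // filter coordinate per dim, undo flip/dilation/padding, and mark the
+  // tap invalid if it falls outside the (dilated) input.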
+  int index_in[NDIM] = {};
+  int wt_stride = 1;
+#pragma unroll
+  for (int i = NDIM - 1; i >= 0; --i) {
+    int index_out = index_out_spatial % params.out_spatial_dims[i];
+    int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
+    out += index_wt * wt_stride;
+
+    if (params.flip) {
+      index_wt = params.wt_spatial_dims[i] - index_wt - 1;
+    }
+
+    int index = index_out * params.strides[i] - params.padding[i] +
+        index_wt * params.kernel_dilation[i];
+    int index_max =
+        1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
+
+    valid &= (index >= 0) && (index < index_max) &&
+        (index % params.input_dilation[i] == 0);
+
+    index_in[i] = index / params.input_dilation[i];
+
+    index_out_spatial /= params.out_spatial_dims[i];
+    index_wt_spatial /= params.wt_spatial_dims[i];
+    wt_stride *= params.wt_spatial_dims[i];
+  }
+
+  if (valid) {
+    int in_offset = index_batch * params.in_strides[0];
+#pragma unroll
+    for (int i = 0; i < NDIM; ++i) {
+      in_offset += index_in[i] * params.in_strides[i + 1];
+    }
+    *out = in[in_offset];
+  } else {
+    *out = T{0};
+  }
+}
+
+} // namespace cu
+
+template <int NDIM>
+array grouped_unfold_transpose_inputs_nd(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    int mat_M,
+    int mat_K,
+    int mat_N,
+    ConvParams<NDIM>& params) {
+  array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
+  unfolded.set_data(allocator::malloc(unfolded.nbytes()));
+  encoder.add_temporary(unfolded);
+
+  int filter_size = params.C;
+#pragma unroll
+  for (int i = 0; i < NDIM; ++i) {
+    filter_size *= params.wt_spatial_dims[i];
+  }
+
+  int out_pixels = 1;
+#pragma unroll
+  for (int i = 0; i < NDIM; ++i) {
+    out_pixels *= params.out_spatial_dims[i];
+  }
+
+  int wt_spatial_size = (mat_K * params.groups) / params.C;
+  dim3 block_dims;
+  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
+  dim3 num_blocks;
+  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
+  num_blocks.y = params.C;
+  num_blocks.z = mat_M;
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(unfolded);
+  dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    encoder.add_kernel_node(
+        cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
+        num_blocks,
+        block_dims,
+        0,
+        in.data<DataType>(),
+        unfolded.data<DataType>(),
+        filter_size,
+        out_pixels,
+        params);
+  });
+
+  return unfolded;
+}
+
+template <int NDIM>
+void gemm_grouped_conv_nd(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    ConvParams<NDIM>& params,
+    Stream s) {
+  // Get gemm shapes.
+  int C_per_group = params.C / params.groups;
+  int O_per_group = params.O / params.groups;
+  int mat_M = out.size() / params.O; // N * H_out * W_out
+  int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
+  int mat_N = O_per_group; // O_per_group
+
+  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
+  array in_unfolded = grouped_unfold_transpose_inputs_nd(
+      encoder, in, mat_M, mat_K, mat_N, params);
+
+  // Transpose weight to (O, C_per_group, H_wt * W_wt) and make it contiguous
+  // for gemm.
+  int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
+  array wt_view(
+      {params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
+  wt_view.copy_shared_buffer(
+      wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
+  array wt_reshaped = contiguous_copy_gpu(wt_view, s);
+
+  // Batch the matmul with one entry per group.
+  Shape batch_shape{params.groups};
+  Strides a_batch_strides{mat_K};
+  Strides b_batch_strides{mat_N * mat_K};
+
+  // Run matmul.
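+  // A single batched GEMM with `groups` batches: group g reads columns
+  // [g * mat_K, (g + 1) * mat_K) of the unfolded input (lda = mat_K * groups),
+  // multiplies its own (mat_N x mat_K) weight block, and writes columns
+  // [g * mat_N, (g + 1) * mat_N) of the output (ld = mat_N * groups).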
+  CublasGemm gemm(
+      encoder.device(),
+      in.dtype(),
+      false, // a_transposed
+      mat_M, // a_rows
+      mat_K, // a_cols
+      mat_K * params.groups, // lda
+      true, // b_transposed
+      mat_K, // b_rows
+      mat_N, // b_cols
+      mat_K, // ldb
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+  gemm.set_out(
+      out.dtype(),
+      false, // out_transposed
+      mat_M, // out_rows
+      mat_N, // out_cols
+      mat_N * params.groups, // out_ld
+      params.groups, // batch_count
+      mat_N); // batch_stride
+  gemm.run(
+      encoder,
+      out,
+      in_unfolded,
+      wt_reshaped,
+      batch_shape,
+      a_batch_strides,
+      b_batch_strides);
+}
+
+void gemm_grouped_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    int groups,
+    bool flip,
+    Stream s) {
+  int conv_ndim = in.ndim() - 2;
+  if (conv_ndim < 1 || conv_ndim > 3) {
+    throw std::runtime_error(fmt::format(
+        "[conv] Unsupported gemm_grouped_conv for {}D conv.", conv_ndim));
+  }
+  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
+    ConvParams<ndim_constant()> params(
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        groups,
+        flip);
+    gemm_grouped_conv_nd(encoder, in, wt, out, params, s);
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 1aeeefa38..836385dfe 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
 }
 
+void CublasGemm::set_out(
+    Dtype dtype,
+    bool transposed,
+    uint64_t rows,
+    uint64_t cols,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride) {
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
+  out_desc_ = create_matrix_layout(
+      dtype_to_cublas_type(dtype),
+      rows,
+      cols,
+      transposed,
+      ld,
+      batch_count,
+      batch_stride);
+}
+
 void CublasGemm::run(
     cu::CommandEncoder& encoder,
     array& out,
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index e093351b6..1b06fb2f7 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -44,6 +44,17 @@ class CublasGemm {
 
   ~CublasGemm();
 
+  // The output descriptor is inferred from the inputs by default; use this
+  // method to override it for unusual output layouts.
+  void set_out(
+      Dtype dtype,
+      bool transposed,
+      uint64_t rows,
+      uint64_t cols,
+      int64_t ld,
+      int32_t batch_count,
+      int64_t batch_stride);
+
   void run(
       cu::CommandEncoder& encoder,
       array& out,
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index ca5e18bb6..78639da21 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -15,8 +15,6 @@ cuda_skip = {
     "TestOps.test_hadamard_grad_vmap",
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
-    "TestConv.test_conv_1d_groups_flipped",
-    "TestConv.test_torch_conv_depthwise",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
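For reference, a minimal CPU sketch of the batching scheme configured above,
as a hypothetical stand-alone program (an illustration, not part of this patch
or of MLX). Group g reads its mat_K-wide column slice of A, its own row-major
(mat_N, mat_K) block of B (consumed transposed), and writes its mat_N-wide
column slice of C, mirroring the lda/ldb/out_ld and batch strides passed to
CublasGemm:

    #include <cstdio>
    #include <vector>

    int main() {
      const int M = 2, K = 3, N = 2, groups = 2;
      // A: (M, K * groups); group g owns columns [g * K, (g + 1) * K).
      std::vector<float> A(M * K * groups), B(groups * N * K), C(M * N * groups);
      for (int i = 0; i < M * K * groups; ++i) A[i] = float(i % 5);
      for (int i = 0; i < groups * N * K; ++i) B[i] = float(i % 3);
      for (int g = 0; g < groups; ++g) {        // batch_count = groups
        const float* Ag = A.data() + g * K;     // a batch_stride = K
        const float* Bg = B.data() + g * N * K; // b batch_stride = N * K
        float* Cg = C.data() + g * N;           // out batch_stride = N
        for (int m = 0; m < M; ++m) {
          for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
              acc += Ag[m * (K * groups) + k] * Bg[n * K + k]; // B used as B^T
            }
            Cg[m * (N * groups) + n] = acc; // interleaved output columns
          }
        }
      }
      std::printf("C[0][0] = %g\n", C[0]);
    }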