mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
[CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv * Add gemm_grouped_conv
This commit is contained in:
parent
65d0d40232
commit
ac85ddfdb7
@ -16,6 +16,8 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/lru_cache.h"
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
@ -21,6 +22,9 @@ namespace {
|
|||||||
#define CONV_BACKWARD_WEIGHT \
|
#define CONV_BACKWARD_WEIGHT \
|
||||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||||
|
|
||||||
|
// Custom placeholder representing fallback kernel.
|
||||||
|
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
|
||||||
|
|
||||||
struct ConvCacheKey {
|
struct ConvCacheKey {
|
||||||
int device_id;
|
int device_id;
|
||||||
cudnnDataType_t cudnn_dtype;
|
cudnnDataType_t cudnn_dtype;
|
||||||
@ -40,7 +44,9 @@ struct ConvCacheKey {
|
|||||||
auto& conv_cache() {
|
auto& conv_cache() {
|
||||||
static LRUBytesKeyCache<
|
static LRUBytesKeyCache<
|
||||||
ConvCacheKey,
|
ConvCacheKey,
|
||||||
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
|
std::pair<
|
||||||
|
cudnnBackendDescriptorType_t,
|
||||||
|
std::optional<cudnn_frontend::ExecutionPlan>>>
|
||||||
cache(/* capacity */ 128);
|
cache(/* capacity */ 128);
|
||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
@ -292,12 +298,29 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
|||||||
get_alignment(out)};
|
get_alignment(out)};
|
||||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||||
auto& [backend_type, plan] = it->second;
|
auto& [backend_type, plan] = it->second;
|
||||||
std::tie(in, wt, out) =
|
if (plan) {
|
||||||
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
// Run cached plan.
|
||||||
register_args(encoder, backend_type, in, wt, out, out_);
|
std::tie(in, wt, out) =
|
||||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
||||||
if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) {
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
|
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||||
|
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Run fallback kernel.
|
||||||
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
kernel_strides_,
|
||||||
|
padding_lo_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_,
|
||||||
|
groups_,
|
||||||
|
flip_,
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -357,25 +380,39 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!op_graph) {
|
|
||||||
throw std::runtime_error("[conv] Can not build op graph.");
|
if (op_graph) {
|
||||||
|
// Setup inputs and outputs.
|
||||||
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
|
|
||||||
|
// Find a plan for the graph and execute it.
|
||||||
|
auto plan = find_cudnn_plan_from_op_graph(
|
||||||
|
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
||||||
|
if (!plan) {
|
||||||
|
throw std::runtime_error("[conv] Unable to find an execution plan.");
|
||||||
|
}
|
||||||
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
|
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||||
|
conv_cache().emplace(
|
||||||
|
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup inputs and outputs.
|
// Use fallback kernel for settings not supported by cuDNN.
|
||||||
register_args(encoder, backend_type, in, wt, out, out_);
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
// Find a plan for the graph and execute it.
|
in,
|
||||||
auto plan = find_cudnn_plan_from_op_graph(
|
wt,
|
||||||
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
out,
|
||||||
if (!plan) {
|
kernel_strides_,
|
||||||
throw std::runtime_error("[conv] Unable to find an execution plan.");
|
padding_lo_,
|
||||||
}
|
kernel_dilation_,
|
||||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
input_dilation_,
|
||||||
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
groups_,
|
||||||
throw std::runtime_error("[conv] Failed to run execution plan.");
|
flip_,
|
||||||
}
|
s);
|
||||||
conv_cache().emplace(
|
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
|
||||||
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
126
mlx/backend/cuda/conv/conv.h
Normal file
126
mlx/backend/cuda/conv/conv.h
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
struct ConvParams {
|
||||||
|
int N; // Batch size
|
||||||
|
int C; // In channels
|
||||||
|
int O; // Out channels
|
||||||
|
int strides[NDIM];
|
||||||
|
int padding[NDIM];
|
||||||
|
int kernel_dilation[NDIM];
|
||||||
|
int input_dilation[NDIM];
|
||||||
|
int groups;
|
||||||
|
bool flip;
|
||||||
|
int in_spatial_dims[NDIM];
|
||||||
|
int wt_spatial_dims[NDIM];
|
||||||
|
int out_spatial_dims[NDIM];
|
||||||
|
int64_t in_strides[NDIM + 2];
|
||||||
|
|
||||||
|
ConvParams(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
const array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip)
|
||||||
|
: N(in.shape(0)),
|
||||||
|
C(in.shape(-1)),
|
||||||
|
O(wt.shape(0)),
|
||||||
|
groups(groups),
|
||||||
|
flip(flip) {
|
||||||
|
std::copy_n(strides.begin(), NDIM, this->strides);
|
||||||
|
std::copy_n(padding.begin(), NDIM, this->padding);
|
||||||
|
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
|
||||||
|
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
|
||||||
|
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
|
||||||
|
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
|
||||||
|
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
|
||||||
|
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void gemm_grouped_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s);
|
||||||
|
|
||||||
|
void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
bool flip,
|
||||||
|
Stream s);
|
||||||
|
|
||||||
|
inline void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array in,
|
||||||
|
array wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
in = contiguous_copy_gpu(in, s);
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
|
encoder.add_temporary(wt);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groups == 1) {
|
||||||
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
flip,
|
||||||
|
s);
|
||||||
|
} else {
|
||||||
|
gemm_grouped_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
groups,
|
||||||
|
flip,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
217
mlx/backend/cuda/conv/gemm_conv.cu
Normal file
217
mlx/backend/cuda/conv/gemm_conv.cu
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, int NDIM>
|
||||||
|
__global__ void naive_unfold_nd(
|
||||||
|
const T* in,
|
||||||
|
T* out,
|
||||||
|
int filter_size,
|
||||||
|
int out_pixels,
|
||||||
|
const __grid_constant__ ConvParams<NDIM> params) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto tid = block.group_index();
|
||||||
|
auto lid = block.thread_index();
|
||||||
|
|
||||||
|
int index_batch = tid.z / out_pixels; // [0, N)
|
||||||
|
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||||
|
int index_wt_spatial =
|
||||||
|
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||||
|
|
||||||
|
if (index_wt_spatial >= filter_size / params.C) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
in += tid.y; // [0, C)
|
||||||
|
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
|
||||||
|
|
||||||
|
bool valid = index_batch < params.N;
|
||||||
|
|
||||||
|
// Get the coordinates in input.
|
||||||
|
int index_in[NDIM] = {};
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||||
|
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||||
|
|
||||||
|
if (params.flip) {
|
||||||
|
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int index = index_out * params.strides[i] - params.padding[i] +
|
||||||
|
index_wt * params.kernel_dilation[i];
|
||||||
|
int index_max =
|
||||||
|
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||||
|
|
||||||
|
valid &= (index >= 0) && (index < index_max) &&
|
||||||
|
(index % params.input_dilation[i] == 0);
|
||||||
|
|
||||||
|
index_in[i] = index / params.input_dilation[i];
|
||||||
|
|
||||||
|
index_out_spatial /= params.out_spatial_dims[i];
|
||||||
|
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (valid) {
|
||||||
|
int in_offset = index_batch * params.in_strides[0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||||
|
}
|
||||||
|
*out = in[in_offset];
|
||||||
|
} else {
|
||||||
|
*out = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
array unfold_inputs_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
int mat_M,
|
||||||
|
int mat_K,
|
||||||
|
int mat_N,
|
||||||
|
ConvParams<NDIM>& params) {
|
||||||
|
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
|
||||||
|
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||||
|
encoder.add_temporary(unfolded);
|
||||||
|
|
||||||
|
int filter_size = params.C;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
filter_size *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_pixels = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
out_pixels *= params.out_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int wt_spatial_size = mat_K / params.C;
|
||||||
|
dim3 block_dims;
|
||||||
|
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||||
|
dim3 num_blocks;
|
||||||
|
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||||
|
num_blocks.y = params.C;
|
||||||
|
num_blocks.z = mat_M;
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(unfolded);
|
||||||
|
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::naive_unfold_nd<DataType, NDIM>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
in.data<DataType>(),
|
||||||
|
unfolded.data<DataType>(),
|
||||||
|
filter_size,
|
||||||
|
out_pixels,
|
||||||
|
params);
|
||||||
|
});
|
||||||
|
|
||||||
|
return unfolded;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
void gemm_conv_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
ConvParams<NDIM>& params,
|
||||||
|
Stream s) {
|
||||||
|
// Get gemm shapes.
|
||||||
|
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||||
|
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
|
||||||
|
int mat_N = params.O; // O
|
||||||
|
|
||||||
|
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||||
|
array in_unfolded =
|
||||||
|
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
|
||||||
|
|
||||||
|
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
|
||||||
|
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
|
||||||
|
wt_reshaped.copy_shared_buffer(
|
||||||
|
wt,
|
||||||
|
{1, mat_K},
|
||||||
|
{false, false, /* col_contiguous */ true},
|
||||||
|
wt.data_size());
|
||||||
|
|
||||||
|
// Single batch.
|
||||||
|
Shape batch_shape{1};
|
||||||
|
Strides a_batch_strides{0};
|
||||||
|
Strides b_batch_strides{0};
|
||||||
|
|
||||||
|
// Run matmul.
|
||||||
|
CublasGemm gemm(
|
||||||
|
encoder.device(),
|
||||||
|
in.dtype(),
|
||||||
|
false, // a_transposed
|
||||||
|
mat_M, // a_rows
|
||||||
|
mat_K, // a_cols
|
||||||
|
mat_K, // lda
|
||||||
|
true, // b_transposed
|
||||||
|
mat_K, // b_rows
|
||||||
|
mat_N, // b_cols
|
||||||
|
mat_K, // ldb
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
gemm.run(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
in_unfolded,
|
||||||
|
wt_reshaped,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
int conv_ndim = in.ndim() - 2;
|
||||||
|
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||||
|
}
|
||||||
|
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||||
|
ConvParams<ndim_constant()> params(
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
1, // groups
|
||||||
|
flip);
|
||||||
|
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
231
mlx/backend/cuda/conv/gemm_grouped_conv.cu
Normal file
231
mlx/backend/cuda/conv/gemm_grouped_conv.cu
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, int NDIM>
|
||||||
|
__global__ void naive_grouped_unfold_transpose_nd(
|
||||||
|
const T* in,
|
||||||
|
T* out,
|
||||||
|
int filter_size,
|
||||||
|
int out_pixels,
|
||||||
|
const __grid_constant__ ConvParams<NDIM> params) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto tid = block.group_index();
|
||||||
|
auto lid = block.thread_index();
|
||||||
|
|
||||||
|
int index_batch = tid.z / out_pixels; // [0, N)
|
||||||
|
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||||
|
int index_wt_spatial =
|
||||||
|
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||||
|
|
||||||
|
if (index_wt_spatial >= filter_size / params.C) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
in += tid.y; // [0, C)
|
||||||
|
out += tid.z * filter_size + tid.y * (filter_size / params.C);
|
||||||
|
|
||||||
|
bool valid = index_batch < params.N;
|
||||||
|
|
||||||
|
// Get the coordinates in input.
|
||||||
|
int index_in[NDIM] = {};
|
||||||
|
int wt_stride = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||||
|
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||||
|
out += index_wt * wt_stride;
|
||||||
|
|
||||||
|
if (params.flip) {
|
||||||
|
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int index = index_out * params.strides[i] - params.padding[i] +
|
||||||
|
index_wt * params.kernel_dilation[i];
|
||||||
|
int index_max =
|
||||||
|
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||||
|
|
||||||
|
valid &= (index >= 0) && (index < index_max) &&
|
||||||
|
(index % params.input_dilation[i] == 0);
|
||||||
|
|
||||||
|
index_in[i] = index / params.input_dilation[i];
|
||||||
|
|
||||||
|
index_out_spatial /= params.out_spatial_dims[i];
|
||||||
|
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||||
|
wt_stride *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (valid) {
|
||||||
|
int in_offset = index_batch * params.in_strides[0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||||
|
}
|
||||||
|
*out = in[in_offset];
|
||||||
|
} else {
|
||||||
|
*out = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
array grouped_unfold_transpose_inputs_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
int mat_M,
|
||||||
|
int mat_K,
|
||||||
|
int mat_N,
|
||||||
|
ConvParams<NDIM>& params) {
|
||||||
|
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
|
||||||
|
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||||
|
encoder.add_temporary(unfolded);
|
||||||
|
|
||||||
|
int filter_size = params.C;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
filter_size *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_pixels = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
out_pixels *= params.out_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int wt_spatial_size = (mat_K * params.groups) / params.C;
|
||||||
|
dim3 block_dims;
|
||||||
|
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||||
|
dim3 num_blocks;
|
||||||
|
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||||
|
num_blocks.y = params.C;
|
||||||
|
num_blocks.z = mat_M;
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(unfolded);
|
||||||
|
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
in.data<DataType>(),
|
||||||
|
unfolded.data<DataType>(),
|
||||||
|
filter_size,
|
||||||
|
out_pixels,
|
||||||
|
params);
|
||||||
|
});
|
||||||
|
|
||||||
|
return unfolded;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
void gemm_grouped_conv_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
ConvParams<NDIM>& params,
|
||||||
|
Stream s) {
|
||||||
|
// Get gemm shapes.
|
||||||
|
int C_per_group = params.C / params.groups;
|
||||||
|
int O_per_group = params.O / params.groups;
|
||||||
|
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||||
|
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
|
||||||
|
int mat_N = O_per_group; // O_per_group
|
||||||
|
|
||||||
|
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||||
|
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
|
||||||
|
encoder, in, mat_M, mat_K, mat_N, params);
|
||||||
|
|
||||||
|
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
|
||||||
|
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
|
||||||
|
array wt_view(
|
||||||
|
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
|
||||||
|
wt_view.copy_shared_buffer(
|
||||||
|
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||||
|
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
|
||||||
|
|
||||||
|
// Batch with size of groups.
|
||||||
|
Shape batch_shape{params.groups};
|
||||||
|
Strides a_batch_strides{mat_K};
|
||||||
|
Strides b_batch_strides{mat_N * mat_K};
|
||||||
|
|
||||||
|
// Run matmul.
|
||||||
|
CublasGemm gemm(
|
||||||
|
encoder.device(),
|
||||||
|
in.dtype(),
|
||||||
|
false, // a_transposed
|
||||||
|
mat_M, // a_rows
|
||||||
|
mat_K, // a_cols
|
||||||
|
mat_K * params.groups, // lda
|
||||||
|
true, // b_transposed
|
||||||
|
mat_K, // b_rows
|
||||||
|
mat_N, // b_cols
|
||||||
|
mat_K, // ldb
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
gemm.set_out(
|
||||||
|
out.dtype(),
|
||||||
|
false, // out_transposed
|
||||||
|
mat_M, // out_rows
|
||||||
|
mat_N, // out_cols
|
||||||
|
mat_N * params.groups, // out_ld
|
||||||
|
params.groups, // batch_count
|
||||||
|
mat_N); // batch_stride
|
||||||
|
gemm.run(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
in_unfolded,
|
||||||
|
wt_reshaped,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_grouped_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
int conv_ndim = in.ndim() - 2;
|
||||||
|
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||||
|
}
|
||||||
|
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||||
|
ConvParams<ndim_constant()> params(
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
groups,
|
||||||
|
flip);
|
||||||
|
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
|
|||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CublasGemm::set_out(
|
||||||
|
Dtype dtype,
|
||||||
|
bool transposed,
|
||||||
|
uint64_t rows,
|
||||||
|
uint64_t cols,
|
||||||
|
int64_t ld,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t batch_stride) {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||||
|
out_desc_ = create_matrix_layout(
|
||||||
|
dtype_to_cublas_type(dtype),
|
||||||
|
rows,
|
||||||
|
cols,
|
||||||
|
transposed,
|
||||||
|
ld,
|
||||||
|
batch_count,
|
||||||
|
batch_stride);
|
||||||
|
}
|
||||||
|
|
||||||
void CublasGemm::run(
|
void CublasGemm::run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
|
@ -44,6 +44,17 @@ class CublasGemm {
|
|||||||
|
|
||||||
~CublasGemm();
|
~CublasGemm();
|
||||||
|
|
||||||
|
// The output's descriptor is inferred from inputs by default, use this method
|
||||||
|
// for unusual output.
|
||||||
|
void set_out(
|
||||||
|
Dtype dtype,
|
||||||
|
bool transposed,
|
||||||
|
uint64_t rows,
|
||||||
|
uint64_t cols,
|
||||||
|
int64_t ld,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t batch_stride);
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
|
@ -15,14 +15,6 @@ cuda_skip = {
|
|||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_conv_1d_groups_flipped",
|
|
||||||
"TestConv.test_conv_general_flip_grad",
|
|
||||||
"TestConv.test_torch_conv_2D",
|
|
||||||
"TestConv.test_torch_conv_depthwise",
|
|
||||||
"TestConv.test_torch_conv_general",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
|
||||||
# FFTs NYI
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
|
Loading…
Reference in New Issue
Block a user