[CUDA] Backward convolution (#2431)

This commit is contained in:
Cheng 2025-08-01 09:54:05 +09:00 committed by GitHub
parent 8b25ce62d5
commit 86c6a15571
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 315 additions and 112 deletions

View File

@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
} }
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -196,6 +196,9 @@ void shared_buffer_reshape(
const Strides& out_strides, const Strides& out_strides,
array& out); array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T> template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) { inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index)); vec.erase(std::next(vec.begin(), index));

View File

@ -16,7 +16,6 @@
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cassert> #include <cassert>
#include <numeric>
namespace mlx::core { namespace mlx::core {
@ -25,25 +24,34 @@ namespace {
// Not all engines support it so can not use this API now. // Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0 #define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
struct ConvCacheKey { struct ConvCacheKey {
int device_id; int device_id;
cudnnBackendDescriptorType_t backend_type; cudnnDataType_t cudnn_dtype;
cudnnDataType_t cudnn_type;
std::array<int, MAX_NDIM> input_shape; std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> filter_shape; std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> padding_lo; std::array<int, MAX_NDIM> padding_lo;
std::array<int, MAX_NDIM> padding_hi; std::array<int, MAX_NDIM> padding_hi;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> dilation; std::array<int, MAX_NDIM> dilation;
int groups; int groups;
bool flip;
uint8_t input_alignment; uint8_t input_alignment;
uint8_t filter_alignment; uint8_t weight_alignment;
uint8_t output_alignment; uint8_t output_alignment;
}; };
auto& conv_cache() { auto& conv_cache() {
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache( static LRUBytesKeyCache<
/* capacity */ 128); ConvCacheKey,
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
cache(/* capacity */ 128);
return cache; return cache;
} }
@ -162,17 +170,17 @@ cudnn_frontend::EngineConfigList get_engine_configs(
bool execute_plan( bool execute_plan(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan, cudnn_frontend::ExecutionPlan& plan,
const array& in, array& x,
const array& wt, array& w,
array& out) { array& y) {
int workspace_size = plan.getWorkspaceSize(); int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8); array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'}; int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = { void* data_ptrs[3] = {
const_cast<void*>(in.data<void>()), x.data<void>(),
const_cast<void*>(wt.data<void>()), w.data<void>(),
out.data<void>(), y.data<void>(),
}; };
auto variantPack = cudnn_frontend::VariantPackBuilder() auto variantPack = cudnn_frontend::VariantPackBuilder()
@ -212,46 +220,154 @@ bool execute_plan(
bool try_engines( bool try_engines(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
cudnn_frontend::EngineConfigList& configs,
const ConvCacheKey& cache_key, const ConvCacheKey& cache_key,
cudnnBackendDescriptorType_t backend_type,
cudnn_frontend::EngineConfigList& configs,
const std::string& op_graph_tag, const std::string& op_graph_tag,
const array& in, array& x,
const array& wt, array& w,
array& out) { array& y) {
for (auto& config : configs) { for (auto& config : configs) {
try { try {
auto plan = cudnn_frontend::ExecutionPlanBuilder() auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle()) .setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag) .setEngineConfig(config, op_graph_tag)
.build(); .build();
if (execute_plan(encoder, plan, in, wt, out)) { if (execute_plan(encoder, plan, x, w, y)) {
conv_cache().emplace(cache_key, std::move(plan)); conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(plan)));
return true; return true;
} }
} catch (cudnn_frontend::cudnnException&) { } catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
} }
} }
return false; return false;
} }
} // namespace auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
array& x,
array& w,
array& y,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding_lo_,
const std::vector<int>& padding_hi_,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation) {
auto padding_lo = convert_vector<int64_t>(padding_lo_);
auto padding_hi = convert_vector<int64_t>(padding_hi_);
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) { if (backend_type == CONV_BACKWARD_INPUT) {
nvtx3::scoped_range r("Convolution::eval_gpu"); for (int i = 0; i < padding_lo.size(); ++i) {
if (out.size() == 0) { int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
return; padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
convert_vector<int64_t>(input_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
padding_hi = padding_lo;
return std::make_tuple(
convert_vector<int64_t>(kernel_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_strides));
} else {
return std::make_tuple(
convert_vector<int64_t>(kernel_strides),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
}
}
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
array& x,
array& w,
array& y,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', x))
.setwDesc(build_tensor('w', w))
.setyDesc(build_tensor('y', y))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
return std::nullopt;
}
}
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array in,
array wt,
array out,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = swapaxes_in_eval(wt, 0, -1);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = swapaxes_in_eval(in, 0, -1);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped.
Shape shape(out.shape());
std::swap(shape.front(), shape.back());
Strides strides(shape.size(), 1);
for (int i = shape.size() - 2; i >= 0; --i) {
strides[i] = shape[i + 1] * strides[i + 1];
}
array intermediate(std::move(shape), out.dtype(), nullptr, {});
intermediate.copy_shared_buffer(
out, std::move(strides), {true, true, false}, out.data_size());
out = intermediate;
} }
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// cuDNN requires contiguous input. // cuDNN requires contiguous input.
// TODO: Handle NCHW format specially.
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s); in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in); encoder.add_temporary(in);
@ -261,80 +377,170 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporary(wt); encoder.add_temporary(wt);
} }
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& intermediate_out,
array& final_out) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_input_array(wt); encoder.set_input_array(wt);
encoder.set_output_array(out); encoder.set_output_array(final_out);
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR; if (backend_type == CONV_BACKWARD_WEIGHT) {
auto cudnn_type = dtype_to_cudnn_type(in.dtype()); // Turn |out| into a strided array, which will have C_in and C_out swapped
// in vjp and the final |grad_weight| will then be contiguous.
Strides strides = intermediate_out.strides();
std::swap(strides.front(), strides.back());
final_out.copy_shared_buffer(
intermediate_out,
std::move(strides),
{false, false, false},
intermediate_out.data_size());
}
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
nvtx3::scoped_range r("Convolution::eval_gpu");
if (out_.size() == 0) {
return;
}
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache. // Search cache.
ConvCacheKey cache_key{ ConvCacheKey cache_key{
encoder.device().cuda_device(), encoder.device().cuda_device(),
backend_type, dtype_to_cudnn_type(dtype),
cudnn_type,
fixed_vector(in.shape()), fixed_vector(in.shape()),
fixed_vector(wt.shape()), fixed_vector(wt.shape()),
fixed_vector(kernel_strides_),
fixed_vector(padding_lo_), fixed_vector(padding_lo_),
fixed_vector(padding_hi_), fixed_vector(padding_hi_),
fixed_vector(kernel_strides_),
fixed_vector(kernel_dilation_), fixed_vector(kernel_dilation_),
groups_, groups_,
flip_,
get_alignment(in), get_alignment(in),
get_alignment(wt), get_alignment(wt),
get_alignment(out)}; get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) { if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
if (!execute_plan(encoder, it->second, in, wt, out)) { auto& [backend_type, plan] = it->second;
throw std::runtime_error("Cached convolution plan failed to execute."); std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
} }
return; return;
} }
// Build operation graph. // There is no reliable way to deduce the proper cuDNN backend for the
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16) // convolution, so we make a best guess and then try.
? CUDNN_DATA_FLOAT std::vector<cudnnBackendDescriptorType_t> try_backends;
: cudnn_type; if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
} else {
// Otherwise it could be backward weight convolution or forward convolution,
// mathematically there is no difference so we have to use heuristics.
// Empirically backward convolutions have large kernel dimensions, and
// usually have |in| and |wt| transposed.
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
wt.shape(2) > out.shape(2)) {
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
} else {
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
}
}
auto stride = convert_vector<int64_t>(kernel_strides_); // Try to build op graph.
auto padding_lo = convert_vector<int64_t>(padding_lo_); cudnnBackendDescriptorType_t backend_type;
auto padding_hi = convert_vector<int64_t>(padding_hi_); std::optional<cudnn_frontend::OperationGraph> op_graph;
auto dilation = convert_vector<int64_t>(kernel_dilation_); for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,
x,
w,
y,
kernel_strides_,
padding_lo_,
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_op_graph(
encoder,
try_backend,
dtype,
x,
w,
y,
stride,
padding_lo,
padding_hi,
dilation);
if (op_graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
break;
}
}
if (!op_graph) {
throw std::runtime_error("[conv] Can not build op graph.");
}
auto conv_desc = cudnn_frontend::ConvDescBuilder() // Get ready to execute the graph.
.setDataType(compute_data_type) register_args(encoder, backend_type, in, wt, out, out_);
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', in))
.setwDesc(build_tensor('w', wt))
.setyDesc(build_tensor('y', out))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
// Try to run plans based on heuristics. // Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph); auto configs = get_engine_configs(backend_type, dtype, *op_graph);
auto op_graph_tag = op_graph.getTag(); auto tag = op_graph->getTag();
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) { auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return; return;
} }
// Then try fallback plans. // Then try fallback plans.
configs = get_engine_configs(backend_type, in.dtype(), op_graph); configs = get_engine_configs(backend_type, dtype, *op_graph);
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) { if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return; return;
} }
throw std::runtime_error("Unable to find an engine for convolution."); throw std::runtime_error("[conv] Unable to find a working engine.");
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -13,7 +13,6 @@
#include <cub/device/device_segmented_sort.cuh> #include <cub/device/device_segmented_sort.cuh>
#include <cassert> #include <cassert>
#include <numeric>
namespace mlx::core { namespace mlx::core {
@ -27,29 +26,6 @@ struct ModOp {
} }
}; };
// We can not use any op in eval, make an utility.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
// TODO: Share the code with Transpose::eval.
Shape shape(axes.size());
Strides strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
strides[ax] = in.strides()[axes[ax]];
}
auto flags = in.flags();
if (flags.contiguous) {
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
}
array out(shape, in.dtype(), nullptr, {});
out.copy_shared_buffer(in, strides, flags, in.data_size());
return out;
}
struct OffsetTransform { struct OffsetTransform {
int nsort; int nsort;

View File

@ -18,21 +18,12 @@ cuda_skip = {
"TestConv.test_conv_1d_groups_flipped", "TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad", "TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad", "TestConv.test_conv_groups_grad",
"TestConv.test_torch_conv_1D_grad",
"TestConv.test_torch_conv_2D", "TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_2D_grad",
"TestConv.test_torch_conv_3D_grad",
"TestConv.test_torch_conv_depthwise", "TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general", "TestConv.test_torch_conv_general",
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_1D",
"TestConvTranspose.test_torch_conv_transpose_1D_grad", "TestConvTranspose.test_torch_conv_transpose_1D_grad",
"TestConvTranspose.test_torch_conv_transpose_2D",
"TestConvTranspose.test_torch_conv_transpose_2D_grad", "TestConvTranspose.test_torch_conv_transpose_2D_grad",
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_3D",
"TestConvTranspose.test_torch_conv_transpose_3D_grad", "TestConvTranspose.test_torch_conv_transpose_3D_grad",
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
# FFTs NYI # FFTs NYI
"TestFFT.test_fft", "TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_big_powers_of_two",