mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
[CUDA] Fix conv grads with groups
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/reshape.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -336,6 +337,42 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
|
||||
array group_transpose(
|
||||
const array& x,
|
||||
int groups,
|
||||
int group_dim,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Stream s) {
|
||||
if (groups == 1) {
|
||||
return swapaxes_in_eval(x, axis1, axis2);
|
||||
}
|
||||
int ndim = x.ndim();
|
||||
if (group_dim < 0) {
|
||||
group_dim += ndim;
|
||||
}
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
if (group_dim <= axis1) {
|
||||
axis1 += 1;
|
||||
}
|
||||
if (group_dim <= axis2) {
|
||||
axis2 += 1;
|
||||
}
|
||||
auto shape = x.shape();
|
||||
shape.insert(shape.begin() + group_dim, groups);
|
||||
shape[group_dim + 1] = shape[group_dim + 1] / groups;
|
||||
array x_trans = reshape_in_eval(x, std::move(shape), s);
|
||||
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
|
||||
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
|
||||
return x_trans;
|
||||
}
|
||||
|
||||
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||
// eval_gpu, at the cost of possibly redundant copies.
|
||||
@@ -345,13 +382,14 @@ std::tuple<array, array, array> prepare_args(
|
||||
array in,
|
||||
array wt,
|
||||
array out,
|
||||
int groups,
|
||||
Stream s) {
|
||||
// Transpose the args depending on the backend type.
|
||||
// TODO: Handle groups.
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
wt = group_transpose(wt, groups, 0, 0, -1, s);
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
in = swapaxes_in_eval(in, 0, -1);
|
||||
in = group_transpose(in, groups, -1, 0, -1, s);
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
// Create a contiguous array that shares the data with |out|, but with dim
|
||||
// C_in and C_out swapped.
|
||||
@@ -457,7 +495,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
auto& [backend_type, plan] = it->second;
|
||||
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
|
||||
std::tie(in, wt, out) =
|
||||
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!execute_plan(encoder, plan, x, w, y)) {
|
||||
@@ -490,7 +529,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||
for (auto try_backend : try_backends) {
|
||||
auto [in_copy, wt_copy, out_copy] =
|
||||
prepare_args(encoder, try_backend, in, wt, out, s);
|
||||
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
|
||||
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||
try_backend,
|
||||
|
@@ -1,9 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/reshape.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
|
@@ -17,7 +17,6 @@ cuda_skip = {
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
"TestConv.test_torch_conv_depthwise",
|
||||
"TestConv.test_torch_conv_general",
|
||||
|
Reference in New Issue
Block a user