Mirror of https://github.com/ml-explore/mlx.git
[CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file
* [CUDA] Fix conv grads with groups
* Put the reshape utils in gpu/copy.h
parent 37b440faa8
commit 1ba18ff7d9
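
This commit fixes the CUDA gradients of grouped convolutions. A hypothetical
minimal repro in MLX's Python API (shapes, names, and the call pattern are
illustrative assumptions, not taken from the commit):

import mlx.core as mx

# Hypothetical repro: gradient of a grouped conv on the CUDA backend.
# MLX uses NHWC inputs and (C_out, kH, kW, C_in / groups) weights.
x = mx.random.normal((1, 8, 8, 4))   # N, H, W, C_in
w = mx.random.normal((8, 3, 3, 2))   # C_out, kH, kW, C_in / groups

def loss(w):
    return mx.conv2d(x, w, padding=1, groups=2).sum()

grad_w = mx.grad(loss)(w)  # was incorrect on CUDA for groups > 1
print(grad_w.shape)        # (8, 3, 3, 2)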
@@ -228,31 +228,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
       std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
 }
-
-array swapaxes_in_eval(const array& x, int axis1, int axis2) {
-  int ndim = x.ndim();
-  if (axis1 < 0) {
-    axis1 += ndim;
-  }
-  if (axis2 < 0) {
-    axis2 += ndim;
-  }
-
-  auto shape = x.shape();
-  std::swap(shape[axis1], shape[axis2]);
-  auto strides = x.strides();
-  std::swap(strides[axis1], strides[axis2]);
-
-  auto [data_size, row_contiguous, col_contiguous] =
-      check_contiguity(shape, strides);
-  bool contiguous = data_size == x.data_size();
-
-  array out(std::move(shape), x.dtype(), nullptr, {});
-  out.copy_shared_buffer(
-      x,
-      std::move(strides),
-      {contiguous, row_contiguous, col_contiguous},
-      x.data_size());
-  return out;
-}
 
 } // namespace mlx::core
@@ -196,9 +196,6 @@ void shared_buffer_reshape(
     const Strides& out_strides,
     array& out);
 
-// Like the swapaxes op but safe to call in eval_gpu.
-array swapaxes_in_eval(const array& x, int axis1, int axis2);
-
 template <typename T>
 inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
   vec.erase(std::next(vec.begin(), index));
@@ -336,6 +336,42 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
   }
 }
 
+// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
+array group_transpose(
+    const array& x,
+    int groups,
+    int group_dim,
+    int axis1,
+    int axis2,
+    Stream s) {
+  if (groups == 1) {
+    return swapaxes_in_eval(x, axis1, axis2);
+  }
+  int ndim = x.ndim();
+  if (group_dim < 0) {
+    group_dim += ndim;
+  }
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+  if (group_dim <= axis1) {
+    axis1 += 1;
+  }
+  if (group_dim <= axis2) {
+    axis2 += 1;
+  }
+  auto shape = x.shape();
+  shape.insert(shape.begin() + group_dim, groups);
+  shape[group_dim + 1] = shape[group_dim + 1] / groups;
+  array x_trans = reshape_in_eval(x, std::move(shape), s);
+  x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
+  x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
+  return x_trans;
+}
+
 // Do necessary transposes and copies to prepare the inputs and outputs for
 // building the cuDNN conv op. It is safe to be called multiple times in one
 // eval_gpu, with cost of possible redundant copies.
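
For intuition, the index arithmetic of group_transpose in the backward-input
case (group_dim = 0, axis1 = 0, axis2 = -1) can be replayed in NumPy. An
illustrative sketch with made-up sizes, not MLX code:

import numpy as np

# Weight (C_out=8, kH=3, kW=3, C_in/groups=2) with groups=2 becomes
# (C_in=4, kH, kW, C_out/groups=4), matching the comment above.
g = 2
w = np.arange(8 * 3 * 3 * 2).reshape(8, 3, 3, 2)
w_split = w.reshape(g, 8 // g, 3, 3, 2)    # split the group dim out of C_out
w_swap = np.swapaxes(w_split, 1, 4)        # axes shifted by the inserted dim
w_t = w_swap.reshape(g * 2, 3, 3, 8 // g)  # fold groups back: flatten(0, 1)
print(w_t.shape)  # (4, 3, 3, 4)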
@@ -345,13 +381,14 @@ std::tuple<array, array, array> prepare_args(
     array in,
     array wt,
     array out,
+    int groups,
     Stream s) {
   // Transpose the args depending on the backend type.
   // TODO: Handle groups.
   if (backend_type == CONV_BACKWARD_INPUT) {
-    wt = swapaxes_in_eval(wt, 0, -1);
+    wt = group_transpose(wt, groups, 0, 0, -1, s);
   } else if (backend_type == CONV_BACKWARD_WEIGHT) {
-    in = swapaxes_in_eval(in, 0, -1);
+    in = group_transpose(in, groups, -1, 0, -1, s);
     wt = swapaxes_in_eval(wt, 0, -1);
     // Create a contiguous array that shares the data with |out|, but with dim
     // C_in and C_out swapped.
@@ -457,7 +494,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
     auto& [backend_type, plan] = it->second;
-    std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
+    std::tie(in, wt, out) =
+        prepare_args(encoder, backend_type, in, wt, out, groups_, s);
     register_args(encoder, backend_type, in, wt, out, out_);
     auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
     if (!execute_plan(encoder, plan, x, w, y)) {
@@ -490,7 +528,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   std::optional<cudnn_frontend::OperationGraph> op_graph;
   for (auto try_backend : try_backends) {
     auto [in_copy, wt_copy, out_copy] =
-        prepare_args(encoder, try_backend, in, wt, out, s);
+        prepare_args(encoder, try_backend, in, wt, out, groups_, s);
     auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
     auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
         try_backend,
@@ -1,6 +1,5 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/gpu/copy.h"
@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
   return arr_copy;
 }
 
+void reshape_gpu(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
+array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
+  int ndim = x.ndim();
+  if (start_axis < 0) {
+    start_axis += ndim;
+  }
+  if (end_axis < 0) {
+    end_axis += ndim;
+  }
+  start_axis = std::max(0, start_axis);
+  end_axis = std::min(ndim - 1, end_axis);
+
+  return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
+}
+
+array reshape_in_eval(const array& x, Shape shape, Stream s) {
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  reshape_gpu(x, out, s);
+  return out;
+}
+
+array swapaxes_in_eval(const array& x, int axis1, int axis2) {
+  int ndim = x.ndim();
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+
+  auto shape = x.shape();
+  std::swap(shape[axis1], shape[axis2]);
+  auto strides = x.strides();
+  std::swap(strides[axis1], strides[axis2]);
+
+  auto [data_size, row_contiguous, col_contiguous] =
+      check_contiguity(shape, strides);
+  bool contiguous = data_size == x.data_size();
+
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  out.copy_shared_buffer(
+      x,
+      std::move(strides),
+      {contiguous, row_contiguous, col_contiguous},
+      x.data_size());
+  return out;
+}
+
 } // namespace mlx::core
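
Note that swapaxes_in_eval never copies data: it swaps the two shape and
stride entries and re-shares the buffer via copy_shared_buffer. A NumPy
sketch of the same zero-copy trick (illustrative only, not MLX code):

import numpy as np

a = np.arange(24).reshape(2, 3, 4)
# Swapping strides along with the shape yields a transposed view of the
# same buffer, with no data movement.
b = np.lib.stride_tricks.as_strided(
    a,
    shape=(2, 4, 3),
    strides=(a.strides[0], a.strides[2], a.strides[1]),
)
assert (b == a.swapaxes(1, 2)).all()
assert np.shares_memory(a, b)  # no copy was made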
@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
 // Return a contiguous array with same shape that copies the data of |arr|.
 array contiguous_copy_gpu(const array& arr, const Stream& s);
 
+// Copy data from |in| and reshape to |out|'s shape.
+void reshape_gpu(const array& in, array& out, Stream s);
+
+// Like the normal ops but safe to call in eval_gpu.
+array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
+array reshape_in_eval(const array& x, Shape shape, Stream s);
+array swapaxes_in_eval(const array& x, int axis1, int axis2);
+
 } // namespace mlx::core
@@ -20,29 +20,6 @@
 
 namespace mlx::core {
 
-namespace {
-
-void reshape(const array& in, array& out, Stream s) {
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-  if (copy_necessary) {
-    out.set_data(allocator::malloc(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        s);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
-}
-
-} // namespace
-
 void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("AsStrided::eval_gpu");
   eval(inputs, out);
@@ -124,7 +101,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Flatten::eval_gpu");
-  reshape(inputs[0], out, stream());
+  reshape_gpu(inputs[0], out, stream());
 }
 
 void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -150,7 +127,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Reshape::eval_gpu");
-  reshape(inputs[0], out, stream());
+  reshape_gpu(inputs[0], out, stream());
 }
 
 void Split::eval_gpu(
@@ -224,7 +201,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Unflatten::eval_gpu");
-  reshape(inputs[0], out, stream());
+  reshape_gpu(inputs[0], out, stream());
 }
 
 void View::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -17,7 +17,6 @@ cuda_skip = {
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_conv_1d_groups_flipped",
    "TestConv.test_conv_general_flip_grad",
-    "TestConv.test_conv_groups_grad",
     "TestConv.test_torch_conv_2D",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
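
Removing TestConv.test_conv_groups_grad from cuda_skip re-enables the grouped
gradient test on CUDA. A hypothetical check in the spirit of that test,
comparing the grouped weight gradient against an equivalent per-group
construction (shapes and tolerance here are assumptions):

import mlx.core as mx

x = mx.random.normal((2, 8, 8, 4))
w = mx.random.normal((6, 3, 3, 2))  # groups=2, so C_in / groups = 2

def loss_grouped(w):
    return mx.conv2d(x, w, padding=1, groups=2).sum()

def loss_split(w):
    # Per-group convolutions over channel slices should be equivalent.
    outs = [
        mx.conv2d(x[..., 2 * g : 2 * g + 2], w[3 * g : 3 * g + 3], padding=1)
        for g in range(2)
    ]
    return mx.concatenate(outs, axis=-1).sum()

gw1 = mx.grad(loss_grouped)(w)
gw2 = mx.grad(loss_split)(w)
assert mx.allclose(gw1, gw2, atol=1e-4).item()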