diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp
index 4c9e39dc6..ae169e35e 100644
--- a/mlx/backend/common/utils.cpp
+++ b/mlx/backend/common/utils.cpp
@@ -228,31 +228,4 @@ std::pair get_grid_and_block_common(int dim0, int dim1, int dim2) {
       std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
 }
 
-array swapaxes_in_eval(const array& x, int axis1, int axis2) {
-  int ndim = x.ndim();
-  if (axis1 < 0) {
-    axis1 += ndim;
-  }
-  if (axis2 < 0) {
-    axis2 += ndim;
-  }
-
-  auto shape = x.shape();
-  std::swap(shape[axis1], shape[axis2]);
-  auto strides = x.strides();
-  std::swap(strides[axis1], strides[axis2]);
-
-  auto [data_size, row_contiguous, col_contiguous] =
-      check_contiguity(shape, strides);
-  bool contiguous = data_size == x.data_size();
-
-  array out(std::move(shape), x.dtype(), nullptr, {});
-  out.copy_shared_buffer(
-      x,
-      std::move(strides),
-      {contiguous, row_contiguous, col_contiguous},
-      x.data_size());
-  return out;
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index db0da5e10..1b6902ff3 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -196,9 +196,6 @@ void shared_buffer_reshape(
     const Strides& out_strides,
     array& out);
 
-// Like the swapaxes op but safe to call in eval_gpu.
-array swapaxes_in_eval(const array& x, int axis1, int axis2);
-
 template <typename T>
 inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
   vec.erase(std::next(vec.begin(), index));
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 1484e8c46..3d7ef60bc 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -336,6 +336,42 @@ std::optional build_op_graph(
   }
 }
 
+// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
+array group_transpose(
+    const array& x,
+    int groups,
+    int group_dim,
+    int axis1,
+    int axis2,
+    Stream s) {
+  if (groups == 1) {
+    return swapaxes_in_eval(x, axis1, axis2);
+  }
+  int ndim = x.ndim();
+  if (group_dim < 0) {
+    group_dim += ndim;
+  }
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+  if (group_dim <= axis1) {
+    axis1 += 1;
+  }
+  if (group_dim <= axis2) {
+    axis2 += 1;
+  }
+  auto shape = x.shape();
+  shape.insert(shape.begin() + group_dim, groups);
+  shape[group_dim + 1] = shape[group_dim + 1] / groups;
+  array x_trans = reshape_in_eval(x, std::move(shape), s);
+  x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
+  x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
+  return x_trans;
+}
+
 // Do necessary transposes and copies to prepare the inputs and outputs for
 // building the cuDNN conv op. It is safe to be called multiple times in one
 // eval_gpu, with cost of possible redundant copies.
@@ -345,13 +381,14 @@ std::tuple<array, array, array> prepare_args(
     array in,
     array wt,
     array out,
+    int groups,
     Stream s) {
   // Transpose the args depending on the backend type.
   // TODO: Handle groups.
   if (backend_type == CONV_BACKWARD_INPUT) {
-    wt = swapaxes_in_eval(wt, 0, -1);
+    wt = group_transpose(wt, groups, 0, 0, -1, s);
   } else if (backend_type == CONV_BACKWARD_WEIGHT) {
-    in = swapaxes_in_eval(in, 0, -1);
+    in = group_transpose(in, groups, -1, 0, -1, s);
     wt = swapaxes_in_eval(wt, 0, -1);
     // Create a contiguous array that shares the data with |out|, but with dim
     // C_in and C_out swapped.
@@ -457,7 +494,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
     auto& [backend_type, plan] = it->second;
-    std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
+    std::tie(in, wt, out) =
+        prepare_args(encoder, backend_type, in, wt, out, groups_, s);
     register_args(encoder, backend_type, in, wt, out, out_);
     auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
     if (!execute_plan(encoder, plan, x, w, y)) {
@@ -490,7 +528,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   std::optional op_graph;
   for (auto try_backend : try_backends) {
     auto [in_copy, wt_copy, out_copy] =
-        prepare_args(encoder, try_backend, in, wt, out, s);
+        prepare_args(encoder, try_backend, in, wt, out, groups_, s);
     auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
     auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
         try_backend,
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index 5bbd72fd5..e81ae12eb 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -1,6 +1,5 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/gpu/copy.h"
diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp
index 4556f7d98..472ee486b 100644
--- a/mlx/backend/gpu/copy.cpp
+++ b/mlx/backend/gpu/copy.cpp
@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
   return arr_copy;
 }
 
+void reshape_gpu(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
+array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
+  int ndim = x.ndim();
+  if (start_axis < 0) {
+    start_axis += ndim;
+  }
+  if (end_axis < 0) {
+    end_axis += ndim;
+  }
+  start_axis = std::max(0, start_axis);
+  end_axis = std::min(ndim - 1, end_axis);
+
+  return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
+}
+
+array reshape_in_eval(const array& x, Shape shape, Stream s) {
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  reshape_gpu(x, out, s);
+  return out;
+}
+
+array swapaxes_in_eval(const array& x, int axis1, int axis2) {
+  int ndim = x.ndim();
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+
+  auto shape = x.shape();
+  std::swap(shape[axis1], shape[axis2]);
+  auto strides = x.strides();
+  std::swap(strides[axis1], strides[axis2]);
+
+  auto [data_size, row_contiguous, col_contiguous] =
+      check_contiguity(shape, strides);
+  bool contiguous = data_size == x.data_size();
+
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  out.copy_shared_buffer(
+      x,
+      std::move(strides),
+      {contiguous, row_contiguous, col_contiguous},
+      x.data_size());
+  return out;
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/gpu/copy.h b/mlx/backend/gpu/copy.h
index f01fe9fda..274250202 100644
--- a/mlx/backend/gpu/copy.h
+++ b/mlx/backend/gpu/copy.h
@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
 // Return a contiguous array with same shape that copies the data of |arr|.
 array contiguous_copy_gpu(const array& arr, const Stream& s);
 
+// Copy data from |in| and reshape to |out|'s shape.
+void reshape_gpu(const array& in, array& out, Stream s);
+
+// Like the normal ops but safe to call in eval_gpu.
+array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
+array reshape_in_eval(const array& x, Shape shape, Stream s);
+array swapaxes_in_eval(const array& x, int axis1, int axis2);
+
 } // namespace mlx::core
diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp
index 56d389b4f..6017879a5 100644
--- a/mlx/backend/gpu/primitives.cpp
+++ b/mlx/backend/gpu/primitives.cpp
@@ -20,29 +20,6 @@
 namespace mlx::core {
 
-namespace {
-
-void reshape(const array& in, array& out, Stream s) {
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-  if (copy_necessary) {
-    out.set_data(allocator::malloc(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        s);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
-}
-
-} // namespace
-
 void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("AsStrided::eval_gpu");
   eval(inputs, out);
 }
@@ -124,7 +101,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Flatten::eval_gpu");
-  reshape(inputs[0], out, stream());
+  reshape_gpu(inputs[0], out, stream());
 }
 
 void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -150,7 +127,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Reshape::eval_gpu");
-  reshape(inputs[0], out, stream());
+  reshape_gpu(inputs[0], out, stream());
 }
 
 void Split::eval_gpu(
@@ -224,7 +201,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("Unflatten::eval_gpu");
-  reshape(inputs[0], out, stream());
+  reshape_gpu(inputs[0], out, stream());
 }
 
 void View::eval_gpu(const std::vector<array>& inputs, array& out) {
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 0f57100e8..c635de9ad 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -17,7 +17,6 @@ cuda_skip = {
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
-    "TestConv.test_conv_groups_grad",
     "TestConv.test_torch_conv_2D",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
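
Note on the new group_transpose helper in mlx/backend/cuda/conv.cpp: for grouped convolutions it splits the channel dimension into (groups, channels / groups), swaps the requested axes, and then flattens the group dimension back in. The standalone sketch below is illustrative only; it is not part of the patch and uses a plain std::vector<int> in place of an MLX array to trace the shape bookkeeping for the backward-input case, where a weight of shape (C_out, H, W, C_in / groups) becomes (C_in, H, W, C_out / groups).

// Illustrative sketch only (not MLX code): mimic the shape bookkeeping of
// group_transpose using a plain shape vector.
#include <cstdio>
#include <utility>
#include <vector>

std::vector<int> group_transpose_shape(
    std::vector<int> shape, int groups, int group_dim, int axis1, int axis2) {
  int ndim = static_cast<int>(shape.size());
  if (group_dim < 0) group_dim += ndim;
  if (axis1 < 0) axis1 += ndim;
  if (axis2 < 0) axis2 += ndim;
  if (groups == 1) {
    std::swap(shape[axis1], shape[axis2]);
    return shape;
  }
  // Axes at or after the inserted group axis shift right by one.
  if (group_dim <= axis1) axis1 += 1;
  if (group_dim <= axis2) axis2 += 1;
  // Split group_dim into (groups, size / groups).
  shape.insert(shape.begin() + group_dim, groups);
  shape[group_dim + 1] /= groups;
  // Swap the requested axes, then merge the two group dims back together.
  std::swap(shape[axis1], shape[axis2]);
  shape[group_dim] *= shape[group_dim + 1];
  shape.erase(shape.begin() + group_dim + 1);
  return shape;
}

int main() {
  // Grouped conv weight (C_out=8, H=3, W=3, C_in/groups=2) with groups=2,
  // matching the CONV_BACKWARD_INPUT call group_transpose(wt, groups, 0, 0, -1, s).
  auto t = group_transpose_shape({8, 3, 3, 2}, 2, 0, 0, -1);
  for (int d : t) std::printf("%d ", d);  // prints: 4 3 3 4
  std::printf("\n");
  return 0;
}

With groups = 2 the (8, 3, 3, 2) weight is first viewed as (2, 4, 3, 3, 2), the adjusted axes 1 and 4 are swapped to give (2, 2, 3, 3, 4), and merging the leading group dims yields (4, 3, 3, 4), i.e. (C_in, H, W, C_out / groups).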