diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 2e79c3c25..593b8a3c2 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -29,19 +29,6 @@ using namespace cudnn_frontend;
       throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
   }
 
-auto swapaxes(const array& in, int axis1, int axis2) {
-  std::vector<int> axes(in.ndim());
-  std::iota(axes.begin(), axes.end(), 0);
-  std::swap(axes[axis1], axes[axis2]);
-  std::vector<int64_t> shape(in.ndim());
-  std::vector<int64_t> strides(in.ndim());
-  for (size_t ax = 0; ax < axes.size(); ++ax) {
-    shape[ax] = in.shape()[axes[ax]];
-    strides[ax] = in.strides()[axes[ax]];
-  }
-  return std::make_tuple(shape, strides);
-}
-
 class Convolution {
  public:
   Convolution(
@@ -56,8 +43,7 @@ class Convolution {
       const std::vector<int64_t>& stride,
       const std::vector<int64_t>& padding_lo,
       const std::vector<int64_t>& padding_hi,
-      const std::vector<int64_t>& dilation,
-      int groups)
+      const std::vector<int64_t>& dilation)
       : handle_(device.cudnn_handle()) {
     auto cudnn_type = dtype_to_cudnn_type(dtype);
     bool is_half = dtype == float16 || dtype == bfloat16;
@@ -98,7 +84,7 @@ class Convolution {
   }
 
   void run(
-      cu::CommandEncoder& encoder,
+      CommandEncoder& encoder,
       const void* input,
       const void* filter,
       void* output) {
@@ -160,37 +146,57 @@ class Convolution {
 
 } // namespace cu
 
+namespace {
+
+template <typename T, typename U>
+inline std::vector<T> convert_vector(const std::vector<U>& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
+}
+
+auto nhwc_to_nchw(const array& in) {
+  auto shape = convert_vector<int64_t>(in.shape());
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  auto strides = convert_vector<int64_t>(in.strides());
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(shape, strides);
+}
+
+} // namespace
+
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Convolution::eval_gpu");
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
   assert(inputs.size() == 2);
-  const array& input = inputs[0];
-  const array& filter = inputs[1];
+  const array& in = inputs[0];
+  const array& wt = inputs[1];
   out.set_data(allocator::malloc(out.nbytes()));
 
-  int ndim = input.ndim();
-  auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
-  auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
-  auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_input_array(wt);
+  encoder.set_output_array(out);
+
+  // cuDNN requires dims to be passed as NCHW.
+  auto [input_shape, input_strides] = nhwc_to_nchw(in);
+  auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
+  auto [output_shape, output_strides] = nhwc_to_nchw(out);
 
   cu::Convolution conv(
       cu::device(s.device),
-      input.dtype(),
+      in.dtype(),
       input_shape,
       input_strides,
       filter_shape,
       filter_strides,
       output_shape,
       output_strides,
-      std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
-      std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
-      std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
-      std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
-      groups_);
-  conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
+      convert_vector<int64_t>(kernel_strides_),
+      convert_vector<int64_t>(padding_lo_),
+      convert_vector<int64_t>(padding_hi_),
+      convert_vector<int64_t>(kernel_dilation_));
+  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
 }
 
 } // namespace mlx::core
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 50cb8dcbe..7aceedc88 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -16,7 +16,6 @@ cuda_skip = {
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
-    "TestConv.test_basic_grad_shapes",
     "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index f4c34a279..969bc2ba7 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -4,7 +4,6 @@
 #define _USE_MATH_DEFINES
 
 #include <cmath>
-#include <iostream>
 #include <numeric>
 
 #include "doctest/doctest.h"
@@ -3643,7 +3642,6 @@ TEST_CASE("test conv1d") {
         {1, 3, 4});
 
     auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
-    std::cout << out << std::endl;
     CHECK(allclose(out, expected).item<bool>());
   }
 
@@ -3722,7 +3720,6 @@ TEST_CASE("test conv2d") {
     auto expected = array(
         {1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
 
     auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
-    std::cout << out << std::endl;
     CHECK(allclose(out, expected).item<bool>());
   }
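Reviewer note (not part of the diff): the new `nhwc_to_nchw` helper never copies or transposes data. It only moves the trailing channel entry to position 1 in both the shape and stride vectors, so cuDNN receives an NCHW *description* of the same NHWC buffer. A minimal standalone sketch of that trick against plain `std::vector`, with example shape values chosen here purely for illustration:

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Metadata-only NHWC -> NCHW: move the last (channel) entry of both the
// shape and the strides to position 1; the underlying buffer is untouched.
std::pair<std::vector<int64_t>, std::vector<int64_t>> nhwc_to_nchw(
    std::vector<int64_t> shape, std::vector<int64_t> strides) {
  auto move_last_to_dim1 = [](std::vector<int64_t>& v) {
    int64_t last = v.back();      // copy first to avoid self-referential insert
    v.pop_back();
    v.insert(v.begin() + 1, last);
  };
  move_last_to_dim1(shape);
  move_last_to_dim1(strides);
  return {shape, strides};
}

int main() {
  // A contiguous NHWC tensor of shape (N=2, H=4, W=5, C=3) has
  // row-major strides (60, 15, 3, 1).
  auto [shape, strides] = nhwc_to_nchw({2, 4, 5, 3}, {60, 15, 3, 1});
  for (int64_t s : shape) std::cout << s << ' ';    // prints: 2 3 4 5
  std::cout << '\n';
  for (int64_t s : strides) std::cout << s << ' '; // prints: 60 1 15 3
  std::cout << '\n';
}
```

The resulting stride vector (60, 1, 15, 3) is non-contiguous for an NCHW shape, which is exactly what lets cuDNN walk the original NHWC memory without an explicit transpose.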