diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 593b8a3c2..fc7b0fee5 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
 
@@ -168,12 +169,24 @@ auto nhwc_to_nchw(const array& in) {
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Convolution::eval_gpu");
   assert(inputs.size() == 2);
-  const array& in = inputs[0];
-  const array& wt = inputs[1];
+  array in = inputs[0];
+  array wt = inputs[1];
   out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
+
+  // While cuDNN supports passing arbitrary strides, it would fail to build a
+  // plan with non-contiguous input.
+  if (!in.flags().row_contiguous) {
+    in = contiguous_copy_gpu(in, s);
+    encoder.add_temporary(in);
+  }
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    encoder.add_temporary(wt);
+  }
+
   encoder.set_input_array(in);
   encoder.set_input_array(wt);
   encoder.set_output_array(out);
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 7aceedc88..e212abec3 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -16,12 +16,10 @@ cuda_skip = {
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
-    "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
     "TestConv.test_conv_groups_grad",
     "TestConv.test_numpy_conv",
-    "TestConv.test_repeated_conv",
     "TestConv.test_torch_conv_1D",
     "TestConv.test_torch_conv_1D_grad",
     "TestConv.test_torch_conv_2D",
@@ -39,7 +37,6 @@ cuda_skip = {
     "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
-    "TestExportImport.test_export_conv",
     "TestLayers.test_conv1d",
     "TestLayers.test_conv2d",
     "TestVmap.test_vmap_conv",
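
For context, a minimal user-level sketch (not part of the PR) of the situation the conv.cpp change handles: the convolution input is a strided, non-row-contiguous view, which cuDNN could not build a plan for before the copy-to-contiguous fallback. The shapes and the use of a strided slice to obtain a non-contiguous input are illustrative assumptions, not taken from the removed tests.

import mlx.core as mx

# Illustrative only: NHWC input and (C_out, kH, kW, C_in) weights.
x = mx.random.normal(shape=(2, 16, 16, 8))
w = mx.random.normal(shape=(4, 3, 3, 8))

# A strided slice is assumed here to yield a non-row-contiguous view; with
# this change the CUDA backend copies such inputs to contiguous memory
# (contiguous_copy_gpu) before handing them to cuDNN.
x_view = x[:, ::2, ::2, :]

y = mx.conv2d(x_view, w, stride=(1, 1), padding=(1, 1))
mx.eval(y)
print(y.shape)  # expected: (2, 8, 8, 4)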