From f4c8888cbebefe41c2a6486b3508b2a16d5ae999 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 21 Aug 2025 08:55:26 +0900 Subject: [PATCH] [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521) --- mlx/backend/cuda/cudnn_utils.cpp | 28 ++++++++++++++++++++++++---- python/tests/cuda_skip.py | 2 -- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp index 76bcc5b0b..4fc112891 100644 --- a/mlx/backend/cuda/cudnn_utils.cpp +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -23,6 +23,24 @@ inline cudnn_frontend::Tensor build_cudnn_tensor( .build(); } +// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN +// whether a tensor is contiguous is determined with: +// strides[dim] == shape[dim + 1] * strides[dim + 1] +// So a contiguous array with singleton dims in MLX may be mistakenly treated +// as strided in cuDNN, and we work around it by normalizing the strides. +Strides normalized_strides(const array& x) { + if (!x.flags().row_contiguous || x.ndim() < 2) { + return x.strides(); + } + Strides strides = x.strides(); + for (int i = x.ndim() - 2; i >= 0; --i) { + if (x.shape(i) == 1) { + strides[i] = x.shape(i + 1) * strides[i + 1]; + } + } + return strides; +} + // Return the shape and strides after transposing from NHWC to NCHW. auto nhwc_to_nchw(SmallVector shape, SmallVector strides) { assert(shape.size() >= 3); @@ -33,8 +51,9 @@ auto nhwc_to_nchw(SmallVector shape, SmallVector strides) { return std::make_tuple(std::move(shape), std::move(strides)); } -auto nhwc_to_nchw(const array& x) { - return nhwc_to_nchw(convert_vector(x.shape()), x.strides()); +inline auto nhwc_to_nchw(const array& x) { + return nhwc_to_nchw( + convert_vector(x.shape()), normalized_strides(x)); } // Return available engines for a |op_graph|. 
@@ -140,7 +159,7 @@ bool prepare_cudnn_plan( cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) { auto shape = convert_vector(x.shape()); - return build_cudnn_tensor(id, x, shape, x.strides()); + return build_cudnn_tensor(id, x, shape, normalized_strides(x)); } cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) { @@ -160,7 +179,8 @@ cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) { return build_cudnn_tensor(id, x, shape, strides); } if (x.ndim() == 2) { - int64_t s = x.strides(0); + int64_t s = + x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0); SmallVector shape = {x.shape(0), x.shape(1), 1, 1}; SmallVector strides = {s, x.strides(1), s, s}; return build_cudnn_tensor(id, x, shape, strides); diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 78639da21..2cc4a6c17 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -13,8 +13,6 @@ cuda_skip = { # Hadamard NYI "TestOps.test_hadamard", "TestOps.test_hadamard_grad_vmap", - # Convolutions NYI - "TestConv.test_1d_conv_with_2d", # FFTs NYI "TestFFT.test_fft", "TestFFT.test_fft_big_powers_of_two",