[CUDA] Fix stride of singleton dims before passing to cuDNN (#2521)

Cheng 2025-08-21 08:55:26 +09:00 committed by GitHub
parent 25c1e03205
commit f4c8888cbe
2 changed files with 24 additions and 6 deletions


@@ -23,6 +23,24 @@ inline cudnn_frontend::Tensor build_cudnn_tensor(
       .build();
 }
 
+// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
+// whether a tensor is contiguous is determined with:
+//   strides[dim] == shape[dim + 1] * strides[dim + 1]
+// So a contiguous array with singleton dims in MLX may be mistakenly treated
+// as strided in cuDNN, and we work around it by normalizing the strides.
+Strides normalized_strides(const array& x) {
+  if (!x.flags().row_contiguous || x.ndim() < 2) {
+    return x.strides();
+  }
+  Strides strides = x.strides();
+  for (int i = x.ndim() - 2; i >= 0; --i) {
+    if (x.shape(i) == 1) {
+      strides[i] = x.shape(i + 1) * strides[i + 1];
+    }
+  }
+  return strides;
+}
+
 // Return the shape and strides after transposing from NHWC to NCHW.
 auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   assert(shape.size() >= 3);
@@ -33,8 +51,9 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   return std::make_tuple(std::move(shape), std::move(strides));
 }
 
-auto nhwc_to_nchw(const array& x) {
-  return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
+inline auto nhwc_to_nchw(const array& x) {
+  return nhwc_to_nchw(
+      convert_vector<int64_t>(x.shape()), normalized_strides(x));
 }
 
 // Return available engines for a |op_graph|.
@@ -140,7 +159,7 @@ bool prepare_cudnn_plan(
 
 cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
   auto shape = convert_vector<int64_t>(x.shape());
-  return build_cudnn_tensor(id, x, shape, x.strides());
+  return build_cudnn_tensor(id, x, shape, normalized_strides(x));
 }
 
 cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
@@ -160,7 +179,8 @@ cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
     return build_cudnn_tensor(id, x, shape, strides);
   }
   if (x.ndim() == 2) {
-    int64_t s = x.strides(0);
+    int64_t s =
+        x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
     SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
     SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
     return build_cudnn_tensor(id, x, shape, strides);
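
To make the normalization concrete: a row-contiguous array of shape {1, 4, 4, 8} may report any stride for the leading singleton dim (for example 0 after a broadcast), yet cuDNN only treats the tensor as packed when strides[0] == shape[1] * strides[1] == 128. Below is a minimal standalone sketch of the same loop, using plain std::vector in place of MLX's array/Strides types; the helper name normalize_singleton_strides is illustrative, not part of the diff.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the fix above: rewrite the stride of every singleton dim so that
// strides[i] == shape[i + 1] * strides[i + 1], the packing rule cuDNN checks.
// Assumes the caller already knows the array is row-contiguous.
std::vector<int64_t> normalize_singleton_strides(
    const std::vector<int64_t>& shape,
    std::vector<int64_t> strides) {
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  return strides;
}

int main() {
  // Row-contiguous NHWC tensor with a singleton batch dim; the stride
  // reported for that dim could be anything, e.g. 0 after a broadcast.
  std::vector<int64_t> shape = {1, 4, 4, 8};
  std::vector<int64_t> strides = {0, 32, 8, 1};
  for (int64_t s : normalize_singleton_strides(shape, strides)) {
    std::cout << s << ' ';  // prints: 128 32 8 1
  }
  std::cout << '\n';
  return 0;
}

The loop walks from the innermost dim outward, so an already-normalized inner stride can feed the computation for an outer singleton dim.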


@@ -13,8 +13,6 @@ cuda_skip = {
     # Hadamard NYI
     "TestOps.test_hadamard",
     "TestOps.test_hadamard_grad_vmap",
-    # Convolutions NYI
-    "TestConv.test_1d_conv_with_2d",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",