mirror of https://github.com/ml-explore/mlx.git
[CUDA] Fix stride of singleton dims before passing to cuDNN (#2521)
parent 25c1e03205
commit f4c8888cbe
@@ -23,6 +23,24 @@ inline cudnn_frontend::Tensor build_cudnn_tensor(
       .build();
 }
 
+// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
+// whether a tensor is contiguous is determined with:
+//   strides[dim] == shape[dim + 1] * strides[dim + 1]
+// So a contiguous array with singleton dims in MLX may be mistakenly treated
+// as strided in cuDNN, and we work around it by normalizing the strides.
+Strides normalized_strides(const array& x) {
+  if (!x.flags().row_contiguous || x.ndim() < 2) {
+    return x.strides();
+  }
+  Strides strides = x.strides();
+  for (int i = x.ndim() - 2; i >= 0; --i) {
+    if (x.shape(i) == 1) {
+      strides[i] = x.shape(i + 1) * strides[i + 1];
+    }
+  }
+  return strides;
+}
+
 // Return the shape and strides after transposing from NHWC to NCHW.
 auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   assert(shape.size() >= 3);
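
Note: the normalization only rewrites strides of singleton dims, so no data moves. A minimal standalone sketch of the same logic (not part of the commit; plain std::vector stands in for MLX's array/Strides types):

#include <cassert>
#include <cstdint>
#include <vector>

// Rewrite singleton-dim strides so a row-contiguous array also satisfies
// cuDNN's contiguity check strides[dim] == shape[dim + 1] * strides[dim + 1].
// The row-contiguity check itself is left to the caller, as in the commit.
std::vector<int64_t> normalize(
    const std::vector<int64_t>& shape, std::vector<int64_t> strides) {
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  return strides;
}

int main() {
  // A row-contiguous (1, 4, 1, 8) MLX array may carry strides {1, 8, 1, 1};
  // cuDNN needs {32, 8, 8, 1} to recognize it as contiguous.
  auto fixed = normalize({1, 4, 1, 8}, {1, 8, 1, 1});
  assert((fixed == std::vector<int64_t>{32, 8, 8, 1}));
  return 0;
}
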
@@ -33,8 +51,9 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   return std::make_tuple(std::move(shape), std::move(strides));
 }
 
-auto nhwc_to_nchw(const array& x) {
-  return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
+inline auto nhwc_to_nchw(const array& x) {
+  return nhwc_to_nchw(
+      convert_vector<int64_t>(x.shape()), normalized_strides(x));
 }
 
 // Return available engines for a |op_graph|.
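
Note: for context, a sketch of what the two-argument nhwc_to_nchw is assumed to do (rotate the trailing channel entry of the metadata to position 1; no data movement; an illustration, not the commit's code):

#include <cstdint>
#include <tuple>
#include <vector>

auto nhwc_to_nchw_sketch(
    std::vector<int64_t> shape, std::vector<int64_t> strides) {
  // Move the last (channel) entry of shape and strides to index 1.
  int64_t c = shape.back();
  shape.pop_back();
  shape.insert(shape.begin() + 1, c);
  int64_t cs = strides.back();
  strides.pop_back();
  strides.insert(strides.begin() + 1, cs);
  return std::make_tuple(std::move(shape), std::move(strides));
}

int main() {
  // Contiguous NHWC tensor with N=2, H=4, W=4, C=3.
  auto [shape, strides] = nhwc_to_nchw_sketch({2, 4, 4, 3}, {48, 12, 3, 1});
  // shape   -> {2, 3, 4, 4}
  // strides -> {48, 1, 12, 3}
  return 0;
}
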
@@ -140,7 +159,7 @@ bool prepare_cudnn_plan(
 
 cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
   auto shape = convert_vector<int64_t>(x.shape());
-  return build_cudnn_tensor(id, x, shape, x.strides());
+  return build_cudnn_tensor(id, x, shape, normalized_strides(x));
 }
 
 cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
@@ -160,7 +179,8 @@ cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
     return build_cudnn_tensor(id, x, shape, strides);
   }
   if (x.ndim() == 2) {
-    int64_t s = x.strides(0);
+    int64_t s =
+        x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
     SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
     SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
     return build_cudnn_tensor(id, x, shape, strides);
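
Note: the 2D branch fixes the same singleton-dim problem at this call site. When shape(0) == 1, strides(0) can be arbitrary in MLX, so for a row-contiguous input the canonical row stride shape(1) * strides(1) is used instead. A hypothetical illustration:

#include <cassert>
#include <cstdint>

int main() {
  // Row-contiguous MLX array of shape (1, 8): the singleton batch dim may
  // report any stride, e.g. strides = {1, 1} rather than {8, 1}.
  int64_t shape[2] = {1, 8};
  int64_t strides[2] = {1, 1};
  bool row_contiguous = true; // as checked via x.flags().row_contiguous
  // The old code took strides[0] (here the arbitrary value 1); the new code
  // recovers the canonical row stride 8 regardless of the reported value.
  int64_t s = row_contiguous ? shape[1] * strides[1] : strides[0];
  assert(s == 8);
  return 0;
}
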
@@ -13,8 +13,6 @@ cuda_skip = {
     # Hadamard NYI
     "TestOps.test_hadamard",
     "TestOps.test_hadamard_grad_vmap",
-    # Convolutions NYI
-    "TestConv.test_1d_conv_with_2d",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",