mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
cudnn only accepts contiguous inputs
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -168,12 +169,24 @@ auto nhwc_to_nchw(const array& in) {
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||
assert(inputs.size() == 2);
|
||||
const array& in = inputs[0];
|
||||
const array& wt = inputs[1];
|
||||
array in = inputs[0];
|
||||
array wt = inputs[1];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// While cuDNN supports passing arbitrary strides, it would fail to build a
|
||||
// plan with non-contiguous input.
|
||||
if (!in.flags().row_contiguous) {
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
encoder.add_temporary(wt);
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(out);
|
||||
|
@@ -16,12 +16,10 @@ cuda_skip = {
|
||||
# Convolutions NYI
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_asymmetric_padding",
|
||||
"TestConv.test_conv2d_unaligned_channels",
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_numpy_conv",
|
||||
"TestConv.test_repeated_conv",
|
||||
"TestConv.test_torch_conv_1D",
|
||||
"TestConv.test_torch_conv_1D_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
@@ -39,7 +37,6 @@ cuda_skip = {
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
"TestExportImport.test_export_conv",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestVmap.test_vmap_conv",
|
||||
|
Reference in New Issue
Block a user