mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
cudnn only accepts contiguous inputs
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@@ -168,12 +169,24 @@ auto nhwc_to_nchw(const array& in) {
|
|||||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
const array& in = inputs[0];
|
array in = inputs[0];
|
||||||
const array& wt = inputs[1];
|
array wt = inputs[1];
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
// While cuDNN supports passing arbitrary strides, it would fail to build a
|
||||||
|
// plan with non-contiguous input.
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
in = contiguous_copy_gpu(in, s);
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
|
encoder.add_temporary(wt);
|
||||||
|
}
|
||||||
|
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
encoder.set_input_array(wt);
|
encoder.set_input_array(wt);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|||||||
@@ -16,12 +16,10 @@ cuda_skip = {
|
|||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
"TestConv.test_asymmetric_padding",
|
||||||
"TestConv.test_conv2d_unaligned_channels",
|
|
||||||
"TestConv.test_conv_1d_groups_flipped",
|
"TestConv.test_conv_1d_groups_flipped",
|
||||||
"TestConv.test_conv_general_flip_grad",
|
"TestConv.test_conv_general_flip_grad",
|
||||||
"TestConv.test_conv_groups_grad",
|
"TestConv.test_conv_groups_grad",
|
||||||
"TestConv.test_numpy_conv",
|
"TestConv.test_numpy_conv",
|
||||||
"TestConv.test_repeated_conv",
|
|
||||||
"TestConv.test_torch_conv_1D",
|
"TestConv.test_torch_conv_1D",
|
||||||
"TestConv.test_torch_conv_1D_grad",
|
"TestConv.test_torch_conv_1D_grad",
|
||||||
"TestConv.test_torch_conv_2D",
|
"TestConv.test_torch_conv_2D",
|
||||||
@@ -39,7 +37,6 @@ cuda_skip = {
|
|||||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||||
"TestExportImport.test_export_conv",
|
|
||||||
"TestLayers.test_conv1d",
|
"TestLayers.test_conv1d",
|
||||||
"TestLayers.test_conv2d",
|
"TestLayers.test_conv2d",
|
||||||
"TestVmap.test_vmap_conv",
|
"TestVmap.test_vmap_conv",
|
||||||
|
|||||||
Reference in New Issue
Block a user