From ac85ddfdb70a0e80f2281f42146d3a89999b1609 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 20 Aug 2025 10:06:22 +0900
Subject: [PATCH] [CUDA] Add GEMM-based fallback convolution kernels (#2511)

* Add gemm_conv
* Add gemm_grouped_conv
---
 mlx/backend/cuda/CMakeLists.txt            |   2 +
 mlx/backend/cuda/conv.cpp                  |  85 +++++---
 mlx/backend/cuda/conv/conv.h               | 126 +++++++++++
 mlx/backend/cuda/conv/gemm_conv.cu         | 217 +++++++++++++++++++
 mlx/backend/cuda/conv/gemm_grouped_conv.cu | 231 +++++++++++++++++++++
 mlx/backend/cuda/gemms/cublas_gemm.cpp     |  19 ++
 mlx/backend/cuda/gemms/cublas_gemm.h       |  11 +
 python/tests/cuda_skip.py                  |   8 -
 8 files changed, 667 insertions(+), 32 deletions(-)
 create mode 100644 mlx/backend/cuda/conv/conv.h
 create mode 100644 mlx/backend/cuda/conv/gemm_conv.cu
 create mode 100644 mlx/backend/cuda/conv/gemm_grouped_conv.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index c529af1d2..994307284 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -16,6 +16,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 4c8016e9d..63188fbc8 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
+#include "mlx/backend/cuda/conv/conv.h"
 #include "mlx/backend/cuda/cudnn_utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/lru_cache.h"
@@ -21,6 +22,9 @@ namespace {
 #define CONV_BACKWARD_WEIGHT \
   CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
 
+// Custom placeholder representing fallback kernel.
+#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
+
 struct ConvCacheKey {
   int device_id;
   cudnnDataType_t cudnn_dtype;
@@ -40,7 +44,9 @@ struct ConvCacheKey {
 auto& conv_cache() {
   static LRUBytesKeyCache<
       ConvCacheKey,
-      std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
+      std::pair<
+          cudnnBackendDescriptorType_t,
+          std::optional<cudnn_frontend::ExecutionPlan>>>
       cache(/* capacity */ 128);
   return cache;
 }
@@ -292,12 +298,29 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
     auto& [backend_type, plan] = it->second;
-    std::tie(in, wt, out) =
-        prepare_args(encoder, backend_type, in, wt, out, groups_, s);
-    register_args(encoder, backend_type, in, wt, out, out_);
-    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-    if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) {
-      throw std::runtime_error("[conv] Cached plan failed to execute.");
+    if (plan) {
+      // Run cached plan.
+      std::tie(in, wt, out) =
+          prepare_args(encoder, backend_type, in, wt, out, groups_, s);
+      register_args(encoder, backend_type, in, wt, out, out_);
+      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+      if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+        throw std::runtime_error("[conv] Cached plan failed to execute.");
+      }
+    } else {
+      // Run fallback kernel.
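+      // Note: a CONV_FALLBACK cache entry stores std::nullopt instead of a
+      // cuDNN execution plan, so repeated calls with the same settings skip
+      // cuDNN entirely and go straight to the GEMM-based kernel below.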
+      gemm_conv(
+          encoder,
+          in,
+          wt,
+          out,
+          kernel_strides_,
+          padding_lo_,
+          kernel_dilation_,
+          input_dilation_,
+          groups_,
+          flip_,
+          s);
     }
     return;
   }
@@ -357,25 +380,39 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       break;
     }
   }
-  if (!op_graph) {
-    throw std::runtime_error("[conv] Can not build op graph.");
+
+  if (op_graph) {
+    // Setup inputs and outputs.
+    register_args(encoder, backend_type, in, wt, out, out_);
+
+    // Find a plan for the graph and execute it.
+    auto plan = find_cudnn_plan_from_op_graph(
+        encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
+    if (!plan) {
+      throw std::runtime_error("[conv] Unable to find an execution plan.");
+    }
+    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+    if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+      conv_cache().emplace(
+          cache_key, std::make_pair(backend_type, std::move(*plan)));
+      return;
+    }
   }
 
-  // Setup inputs and outputs.
-  register_args(encoder, backend_type, in, wt, out, out_);
-
-  // Find a plan for the graph and execute it.
-  auto plan = find_cudnn_plan_from_op_graph(
-      encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
-  if (!plan) {
-    throw std::runtime_error("[conv] Unable to find an execution plan.");
-  }
-  auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-  if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-    throw std::runtime_error("[conv] Failed to run execution plan.");
-  }
-  conv_cache().emplace(
-      cache_key, std::make_pair(backend_type, std::move(*plan)));
+  // Use fallback kernel for settings not supported by cuDNN.
+  gemm_conv(
+      encoder,
+      in,
+      wt,
+      out,
+      kernel_strides_,
+      padding_lo_,
+      kernel_dilation_,
+      input_dilation_,
+      groups_,
+      flip_,
+      s);
+  conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/conv/conv.h b/mlx/backend/cuda/conv/conv.h
new file mode 100644
index 000000000..62dc9343e
--- /dev/null
+++ b/mlx/backend/cuda/conv/conv.h
@@ -0,0 +1,126 @@
+// Copyright © 2025 Apple Inc.
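+//
+// The fallback path lowers convolution to a matrix multiply (an im2col-style
+// unfold followed by a GEMM). Roughly:
+//
+//   unfolded input : (N * H_out * W_out, C * H_wt * W_wt)
+//   weights        : (O, C * H_wt * W_wt), used as a transposed operand
+//   output         : (N * H_out * W_out, O) = unfolded * weights^T
+//
+// ConvParams<NDIM> packs shapes, strides, padding and dilations into a plain
+// struct that the unfold kernels receive as a __grid_constant__ parameter.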
+
+#pragma once
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/gpu/copy.h"
+
+namespace mlx::core {
+
+template <int NDIM>
+struct ConvParams {
+  int N; // Batch size
+  int C; // In channels
+  int O; // Out channels
+  int strides[NDIM];
+  int padding[NDIM];
+  int kernel_dilation[NDIM];
+  int input_dilation[NDIM];
+  int groups;
+  bool flip;
+  int in_spatial_dims[NDIM];
+  int wt_spatial_dims[NDIM];
+  int out_spatial_dims[NDIM];
+  int64_t in_strides[NDIM + 2];
+
+  ConvParams(
+      const array& in,
+      const array& wt,
+      const array& out,
+      const std::vector<int>& strides,
+      const std::vector<int>& padding,
+      const std::vector<int>& kernel_dilation,
+      const std::vector<int>& input_dilation,
+      int groups,
+      bool flip)
+      : N(in.shape(0)),
+        C(in.shape(-1)),
+        O(wt.shape(0)),
+        groups(groups),
+        flip(flip) {
+    std::copy_n(strides.begin(), NDIM, this->strides);
+    std::copy_n(padding.begin(), NDIM, this->padding);
+    std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
+    std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
+    std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
+    std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
+    std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
+    std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
+  }
+};
+
+void gemm_grouped_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    int groups,
+    bool flip,
+    Stream s);
+
+void gemm_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    bool flip,
+    Stream s);
+
+inline void gemm_conv(
+    cu::CommandEncoder& encoder,
+    array in,
+    array wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    int groups,
+    bool flip,
+    Stream s) {
+  if (!in.flags().row_contiguous) {
+    in = contiguous_copy_gpu(in, s);
+    encoder.add_temporary(in);
+  }
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    encoder.add_temporary(wt);
+  }
+
+  if (groups == 1) {
+    gemm_conv(
+        encoder,
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        flip,
+        s);
+  } else {
+    gemm_grouped_conv(
+        encoder,
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        groups,
+        flip,
+        s);
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu
new file mode 100644
index 000000000..11a78a7ab
--- /dev/null
+++ b/mlx/backend/cuda/conv/gemm_conv.cu
@@ -0,0 +1,217 @@
+// Copyright © 2025 Apple Inc.
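+//
+// naive_unfold_nd below implements the im2col step: each thread copies one
+// element of the unfolded matrix (or writes zero for padded / dilated
+// positions). The launch uses grid.z = N * H_out * W_out (one output pixel
+// per z-slice), grid.y = C, and grid.x * blockDim.x covering the kernel's
+// spatial positions.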
+
+#include "mlx/backend/cuda/conv/conv.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, int NDIM>
+__global__ void naive_unfold_nd(
+    const T* in,
+    T* out,
+    int filter_size,
+    int out_pixels,
+    const __grid_constant__ ConvParams<NDIM> params) {
+  auto block = cg::this_thread_block();
+  auto tid = block.group_index();
+  auto lid = block.thread_index();
+
+  int index_batch = tid.z / out_pixels; // [0, N)
+  int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
+  int index_wt_spatial =
+      tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
+
+  if (index_wt_spatial >= filter_size / params.C) {
+    return;
+  }
+
+  in += tid.y; // [0, C)
+  out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
+
+  bool valid = index_batch < params.N;
+
+  // Get the coordinates in input.
+  int index_in[NDIM] = {};
+#pragma unroll
+  for (int i = NDIM - 1; i >= 0; --i) {
+    int index_out = index_out_spatial % params.out_spatial_dims[i];
+    int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
+
+    if (params.flip) {
+      index_wt = params.wt_spatial_dims[i] - index_wt - 1;
+    }
+
+    int index = index_out * params.strides[i] - params.padding[i] +
+        index_wt * params.kernel_dilation[i];
+    int index_max =
+        1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
+
+    valid &= (index >= 0) && (index < index_max) &&
+        (index % params.input_dilation[i] == 0);
+
+    index_in[i] = index / params.input_dilation[i];
+
+    index_out_spatial /= params.out_spatial_dims[i];
+    index_wt_spatial /= params.wt_spatial_dims[i];
+  }
+
+  if (valid) {
+    int in_offset = index_batch * params.in_strides[0];
+#pragma unroll
+    for (int i = 0; i < NDIM; ++i) {
+      in_offset += index_in[i] * params.in_strides[i + 1];
+    }
+    *out = in[in_offset];
+  } else {
+    *out = T{0};
+  }
+}
+
+} // namespace cu
+
+template <int NDIM>
+array unfold_inputs_nd(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    int mat_M,
+    int mat_K,
+    int mat_N,
+    ConvParams<NDIM>& params) {
+  array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
+  unfolded.set_data(allocator::malloc(unfolded.nbytes()));
+  encoder.add_temporary(unfolded);
+
+  int filter_size = params.C;
+#pragma unroll
+  for (int i = 0; i < NDIM; ++i) {
+    filter_size *= params.wt_spatial_dims[i];
+  }
+
+  int out_pixels = 1;
+#pragma unroll
+  for (int i = 0; i < NDIM; ++i) {
+    out_pixels *= params.out_spatial_dims[i];
+  }
+
+  int wt_spatial_size = mat_K / params.C;
+  dim3 block_dims;
+  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
+  dim3 num_blocks;
+  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
+  num_blocks.y = params.C;
+  num_blocks.z = mat_M;
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(unfolded);
+  dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    encoder.add_kernel_node(
+        cu::naive_unfold_nd<DataType, NDIM>,
+        num_blocks,
+        block_dims,
+        0,
+        in.data<DataType>(),
+        unfolded.data<DataType>(),
+        filter_size,
+        out_pixels,
+        params);
+  });
+
+  return unfolded;
+}
+
+template <int NDIM>
+void gemm_conv_nd(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    ConvParams<NDIM>& params,
+    Stream s) {
+  // Get gemm shapes.
+  int mat_M = out.size() / params.O; // N * H_out * W_out
+  int mat_K = wt.size() / params.O; // C * H_wt * W_wt
+  int mat_N = params.O; // O
+
+  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
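+  // The GEMM below then computes out (mat_M x mat_N) as
+  // in_unfolded (mat_M x mat_K) times the reshaped weights (mat_K x mat_N),
+  // where the weights are a col-contiguous view of the (O, mat_K) filter.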
+  array in_unfolded =
+      unfold_inputs_nd(encoder, in, mat_M, mat_K, mat_N, params);
+
+  // Reshape weight to (C * H_wt * W_wt, O) for gemm.
+  array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
+  wt_reshaped.copy_shared_buffer(
+      wt,
+      {1, mat_K},
+      {false, false, /* col_contiguous */ true},
+      wt.data_size());
+
+  // Single batch.
+  Shape batch_shape{1};
+  Strides a_batch_strides{0};
+  Strides b_batch_strides{0};
+
+  // Run matmul.
+  CublasGemm gemm(
+      encoder.device(),
+      in.dtype(),
+      false, // a_transposed
+      mat_M, // a_rows
+      mat_K, // a_cols
+      mat_K, // lda
+      true, // b_transposed
+      mat_K, // b_rows
+      mat_N, // b_cols
+      mat_K, // ldb
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+  gemm.run(
+      encoder,
+      out,
+      in_unfolded,
+      wt_reshaped,
+      batch_shape,
+      a_batch_strides,
+      b_batch_strides);
+}
+
+void gemm_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    bool flip,
+    Stream s) {
+  int conv_ndim = in.ndim() - 2;
+  if (conv_ndim < 1 || conv_ndim > 3) {
+    throw std::runtime_error(
+        fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
+  }
+  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
+    ConvParams<ndim_constant()> params(
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        1, // groups
+        flip);
+    gemm_conv_nd(encoder, in, wt, out, params, s);
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
new file mode 100644
index 000000000..7ceb58166
--- /dev/null
+++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
@@ -0,0 +1,231 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/conv/conv.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, int NDIM>
+__global__ void naive_grouped_unfold_transpose_nd(
+    const T* in,
+    T* out,
+    int filter_size,
+    int out_pixels,
+    const __grid_constant__ ConvParams<NDIM> params) {
+  auto block = cg::this_thread_block();
+  auto tid = block.group_index();
+  auto lid = block.thread_index();
+
+  int index_batch = tid.z / out_pixels; // [0, N)
+  int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
+  int index_wt_spatial =
+      tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
+
+  if (index_wt_spatial >= filter_size / params.C) {
+    return;
+  }
+
+  in += tid.y; // [0, C)
+  out += tid.z * filter_size + tid.y * (filter_size / params.C);
+
+  bool valid = index_batch < params.N;
+
+  // Get the coordinates in input.
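+  // For each spatial dimension: input coordinate = output coordinate * stride
+  // - padding + (optionally flipped) kernel offset * kernel_dilation. The
+  // position is only valid if it lands on an input-dilation grid point inside
+  // the dilated input extent; invalid positions are written as zeros.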
+  int index_in[NDIM] = {};
+  int wt_stride = 1;
+#pragma unroll
+  for (int i = NDIM - 1; i >= 0; --i) {
+    int index_out = index_out_spatial % params.out_spatial_dims[i];
+    int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
+    out += index_wt * wt_stride;
+
+    if (params.flip) {
+      index_wt = params.wt_spatial_dims[i] - index_wt - 1;
+    }
+
+    int index = index_out * params.strides[i] - params.padding[i] +
+        index_wt * params.kernel_dilation[i];
+    int index_max =
+        1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
+
+    valid &= (index >= 0) && (index < index_max) &&
+        (index % params.input_dilation[i] == 0);
+
+    index_in[i] = index / params.input_dilation[i];
+
+    index_out_spatial /= params.out_spatial_dims[i];
+    index_wt_spatial /= params.wt_spatial_dims[i];
+    wt_stride *= params.wt_spatial_dims[i];
+  }
+
+  if (valid) {
+    int in_offset = index_batch * params.in_strides[0];
+#pragma unroll
+    for (int i = 0; i < NDIM; ++i) {
+      in_offset += index_in[i] * params.in_strides[i + 1];
+    }
+    *out = in[in_offset];
+  } else {
+    *out = T{0};
+  }
+}
+
+} // namespace cu
+
+template <int NDIM>
+array grouped_unfold_transpose_inputs_nd(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    int mat_M,
+    int mat_K,
+    int mat_N,
+    ConvParams<NDIM>& params) {
+  array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
+  unfolded.set_data(allocator::malloc(unfolded.nbytes()));
+  encoder.add_temporary(unfolded);
+
+  int filter_size = params.C;
+#pragma unroll
+  for (int i = 0; i < NDIM; ++i) {
+    filter_size *= params.wt_spatial_dims[i];
+  }
+
+  int out_pixels = 1;
+#pragma unroll
+  for (int i = 0; i < NDIM; ++i) {
+    out_pixels *= params.out_spatial_dims[i];
+  }
+
+  int wt_spatial_size = (mat_K * params.groups) / params.C;
+  dim3 block_dims;
+  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
+  dim3 num_blocks;
+  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
+  num_blocks.y = params.C;
+  num_blocks.z = mat_M;
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(unfolded);
+  dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    encoder.add_kernel_node(
+        cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
+        num_blocks,
+        block_dims,
+        0,
+        in.data<DataType>(),
+        unfolded.data<DataType>(),
+        filter_size,
+        out_pixels,
+        params);
+  });
+
+  return unfolded;
+}
+
+template <int NDIM>
+void gemm_grouped_conv_nd(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    ConvParams<NDIM>& params,
+    Stream s) {
+  // Get gemm shapes.
+  int C_per_group = params.C / params.groups;
+  int O_per_group = params.O / params.groups;
+  int mat_M = out.size() / params.O; // N * H_out * W_out
+  int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
+  int mat_N = O_per_group; // O_per_group
+
+  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
+  array in_unfolded = grouped_unfold_transpose_inputs_nd(
+      encoder, in, mat_M, mat_K, mat_N, params);
+
+  // Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
+  int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
+  array wt_view(
+      {params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
+  wt_view.copy_shared_buffer(
+      wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
+  array wt_reshaped = contiguous_copy_gpu(wt_view, s);
+
+  // Batch with size of groups.
+  Shape batch_shape{params.groups};
+  Strides a_batch_strides{mat_K};
+  Strides b_batch_strides{mat_N * mat_K};
+
+  // Run matmul.
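+  // Grouped convolution runs as a batched GEMM with batch_count = groups:
+  // each group reads its own mat_K-wide slice of the unfolded input
+  // (lda = mat_K * groups) and its own (mat_N x mat_K) block of the weights,
+  // while set_out() below points each group at an interleaved mat_N-wide
+  // slice of the output row (out_ld = mat_N * groups, batch stride mat_N).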
+  CublasGemm gemm(
+      encoder.device(),
+      in.dtype(),
+      false, // a_transposed
+      mat_M, // a_rows
+      mat_K, // a_cols
+      mat_K * params.groups, // lda
+      true, // b_transposed
+      mat_K, // b_rows
+      mat_N, // b_cols
+      mat_K, // ldb
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+  gemm.set_out(
+      out.dtype(),
+      false, // out_transposed
+      mat_M, // out_rows
+      mat_N, // out_cols
+      mat_N * params.groups, // out_ld
+      params.groups, // batch_count
+      mat_N); // batch_stride
+  gemm.run(
+      encoder,
+      out,
+      in_unfolded,
+      wt_reshaped,
+      batch_shape,
+      a_batch_strides,
+      b_batch_strides);
+}
+
+void gemm_grouped_conv(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    const array& wt,
+    array& out,
+    const std::vector<int>& strides,
+    const std::vector<int>& padding,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation,
+    int groups,
+    bool flip,
+    Stream s) {
+  int conv_ndim = in.ndim() - 2;
+  if (conv_ndim < 1 || conv_ndim > 3) {
+    throw std::runtime_error(
+        fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
+  }
+  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
+    ConvParams<ndim_constant()> params(
+        in,
+        wt,
+        out,
+        strides,
+        padding,
+        kernel_dilation,
+        input_dilation,
+        groups,
+        flip);
+    gemm_grouped_conv_nd(encoder, in, wt, out, params, s);
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 1aeeefa38..836385dfe 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
 }
 
+void CublasGemm::set_out(
+    Dtype dtype,
+    bool transposed,
+    uint64_t rows,
+    uint64_t cols,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride) {
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
+  out_desc_ = create_matrix_layout(
+      dtype_to_cublas_type(dtype),
+      rows,
+      cols,
+      transposed,
+      ld,
+      batch_count,
+      batch_stride);
+}
+
 void CublasGemm::run(
     cu::CommandEncoder& encoder,
     array& out,
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index e093351b6..1b06fb2f7 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -44,6 +44,17 @@ class CublasGemm {
 
   ~CublasGemm();
 
+  // The output's descriptor is inferred from the inputs by default; use this
+  // method for an unusual output layout.
+  void set_out(
+      Dtype dtype,
+      bool transposed,
+      uint64_t rows,
+      uint64_t cols,
+      int64_t ld,
+      int32_t batch_count,
+      int64_t batch_stride);
+
   void run(
       cu::CommandEncoder& encoder,
       array& out,
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index c635de9ad..78639da21 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -15,14 +15,6 @@ cuda_skip = {
     "TestOps.test_hadamard_grad_vmap",
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
-    "TestConv.test_conv_1d_groups_flipped",
-    "TestConv.test_conv_general_flip_grad",
-    "TestConv.test_torch_conv_2D",
-    "TestConv.test_torch_conv_depthwise",
-    "TestConv.test_torch_conv_general",
-    "TestConvTranspose.test_torch_conv_transpose_1D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_2D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",