From 226a1d24e0bbaed32c8f3c584955d3248198f002 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 10 Oct 2025 16:12:47 -0700
Subject: [PATCH] Debug cuda conv (#2662)

* use t4

* use t4
---
 mlx/backend/common/matmul.h      |  4 ++--
 mlx/backend/cuda/conv.cpp        | 21 ++++++++++-----------
 mlx/backend/cuda/cudnn_utils.cpp |  3 +++
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/mlx/backend/common/matmul.h b/mlx/backend/common/matmul.h
index 2faf256d1..2545c4fde 100644
--- a/mlx/backend/common/matmul.h
+++ b/mlx/backend/common/matmul.h
@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
     const array& a,
     const array& b) {
   if (a.ndim() == 2) {
-    return {{1}, {0}, {0}};
+    return {Shape{1}, Strides{0}, Strides{0}};
   }
 
   Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
 inline std::tuple<Shape, Strides, Strides, Strides>
 collapse_batches(const array& a, const array& b, const array& c) {
   if (a.ndim() == 2) {
-    return {{1}, {0}, {0}, {0}};
+    return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
   }
 
   Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index a5bc8e41a..e65de63e0 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -382,20 +382,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   }
 
   if (op_graph) {
-    // Setup inputs and outputs.
-    register_args(encoder, backend_type, in, wt, out, out_);
-
     // Find a plan for the graph and execute it.
     auto plan = find_cudnn_plan_from_op_graph(
         encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
-    if (!plan) {
-      throw std::runtime_error("[conv] Unable to find an execution plan.");
-    }
-    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-    if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-      conv_cache().emplace(
-          cache_key, std::make_pair(backend_type, std::move(*plan)));
-      return;
+    if (plan) {
+      // Setup inputs and outputs.
+      register_args(encoder, backend_type, in, wt, out, out_);
+
+      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+      if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+        conv_cache().emplace(
+            cache_key, std::make_pair(backend_type, std::move(*plan)));
+        return;
+      }
     }
   }
 
diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp
index 4fc112891..20280f2be 100644
--- a/mlx/backend/cuda/cudnn_utils.cpp
+++ b/mlx/backend/cuda/cudnn_utils.cpp
@@ -210,6 +210,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
     Dtype dtype,
     cudnn_frontend::OperationGraph& op_graph) {
   auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
+  if (engine_configs.empty()) {
+    return std::nullopt;
+  }
   return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
 }
 