From 67a5f7b2a83eabdd4d248e3b793dcaf4d59c6d99 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Sun, 20 Jul 2025 03:40:15 -0700
Subject: [PATCH] Test the native cuda graph api

---
 mlx/backend/cuda/conv.cpp | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index eced11ec1..594b82dde 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -22,6 +22,9 @@ namespace mlx::core {
 
 namespace {
 
+// Not all engines support it, so we cannot use this API yet.
+#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
+
 struct ConvCacheKey {
   int device_id;
   cudnnBackendDescriptorType_t backend_type;
@@ -181,6 +184,20 @@ bool execute_plan(
                          .setUids(3, uids)
                          .build();
 
+#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
+  cudaGraph_t graph;
+  cudaGraphCreate(&graph, 0);
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
+  if (cudnnBackendPopulateCudaGraph(
+          encoder.device().cudnn_handle(),
+          plan.get_raw_desc(),
+          variantPack.get_raw_desc(),
+          graph) != CUDNN_STATUS_SUCCESS) {
+    return false;
+  }
+  encoder.add_graph_node(graph);
+#else
   auto capture = encoder.capture_context();
   if (cudnnBackendExecute(
           encoder.device().cudnn_handle(),
@@ -190,6 +207,7 @@ bool execute_plan(
     capture.discard = true;
     return false;
   }
+#endif
 
   encoder.add_temporary(workspace);
   return true;
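
Note (not part of the commit): the gated branch above uses the cuDNN 9.5+ native CUDA graph API, where cuDNN populates a caller-provided cudaGraph_t from a finalized execution plan and variant pack instead of being executed inside a stream capture. Below is a minimal, self-contained sketch of how that API can be driven outside MLX's encoder. The function name and parameters are hypothetical; `handle`, `plan_desc`, and `variant_pack_desc` are assumed to be a valid cudnnHandle_t and already-finalized backend execution-plan and variant-pack descriptors, and the three-argument cudaGraphInstantiate from CUDA 12 is assumed to be available.

// sketch.cpp -- illustrative only, not MLX code
#include <cuda_runtime.h>
#include <cudnn.h>

bool run_plan_via_cuda_graph(
    cudnnHandle_t handle,
    cudnnBackendDescriptor_t plan_desc,
    cudnnBackendDescriptor_t variant_pack_desc,
    cudaStream_t stream) {
  // Create an empty graph and let cuDNN populate it with the plan's work.
  cudaGraph_t graph;
  if (cudaGraphCreate(&graph, 0) != cudaSuccess) {
    return false;
  }
  if (cudnnBackendPopulateCudaGraph(
          handle, plan_desc, variant_pack_desc, graph) != CUDNN_STATUS_SUCCESS) {
    // Mirrors the patch: fall back when the chosen engine does not support
    // graph population.
    cudaGraphDestroy(graph);
    return false;
  }

  // Instantiate and launch the populated graph on the given stream.
  cudaGraphExec_t exec;
  if (cudaGraphInstantiate(&exec, graph, 0) != cudaSuccess) {
    cudaGraphDestroy(graph);
    return false;
  }
  bool ok = (cudaGraphLaunch(exec, stream) == cudaSuccess);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  return ok;
}

The sketch instantiates and launches the graph itself so it is runnable on its own; the patch instead hands the populated graph to encoder.add_graph_node(graph), letting MLX's command encoder embed it as a child node of the larger per-stream graph it already maintains.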