Test the native cuda graph api

This commit is contained in:
Cheng
2025-07-20 03:40:15 -07:00
parent 85510dae78
commit 67a5f7b2a8

View File

@@ -22,6 +22,9 @@ namespace mlx::core {
namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
struct ConvCacheKey {
int device_id;
cudnnBackendDescriptorType_t backend_type;
@@ -181,6 +184,20 @@ bool execute_plan(
.setUids(3, uids)
.build();
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
encoder.device().cudnn_handle(),
plan.get_raw_desc(),
variantPack.get_raw_desc(),
graph) != CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
encoder.device().cudnn_handle(),
@@ -190,6 +207,7 @@ bool execute_plan(
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;