diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 2e43e2df2..e7224642a 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -348,6 +348,9 @@ std::pair subgraph_to_key(cudaGraph_t graph) { key += subkey; break; } + case cudaGraphNodeTypeHost: + key += "H"; + break; case cudaGraphNodeTypeMemset: key += "M"; break;