diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index bf0946a7b..767170848 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -299,6 +299,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) { graph_exec.instantiate(child); device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream())); + return; } cudaGraphNode_t node; CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));