From 8fb3e7a26c35d768aceac6eb20a4ebf13740b8b2 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 10 Jul 2025 23:24:02 +0900 Subject: [PATCH] [CUDA] Set current device before cudaGraphLaunch (#2351) --- mlx/backend/cuda/device.cpp | 19 ++++++++++--------- mlx/backend/cuda/device.h | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 99ccfdb4a..f7c8ecdc0 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -57,6 +57,14 @@ void Device::make_current() { } } +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; + } + return it->second; +} + CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); CHECK_CUDA_ERROR( @@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) { } } -CommandEncoder& Device::get_command_encoder(Stream s) { - auto it = encoders_.find(s.index); - if (it == encoders_.end()) { - it = encoders_.try_emplace(s.index, *this).first; - } - return it->second; -} - -CommandEncoder::CommandEncoder(Device& d) : stream_(d) { +CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); } @@ -287,6 +287,7 @@ void CommandEncoder::commit() { CHECK_CUDA_ERROR( cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); } + device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); // TODO smarter cache policy diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 4ebdae55c..8ac840cbb 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -93,6 +93,7 @@ class CommandEncoder { void insert_graph_dependencies(GraphNode node); void insert_graph_dependencies(std::vector<GraphNode> nodes); + Device& device_; CudaStream stream_; cudaGraph_t graph_; 
Worker worker_;