diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 99ccfdb4a..f7c8ecdc0 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -57,6 +57,14 @@ void Device::make_current() { } } +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; + } + return it->second; +} + CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); CHECK_CUDA_ERROR( @@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector nodes) { } } -CommandEncoder& Device::get_command_encoder(Stream s) { - auto it = encoders_.find(s.index); - if (it == encoders_.end()) { - it = encoders_.try_emplace(s.index, *this).first; - } - return it->second; -} - -CommandEncoder::CommandEncoder(Device& d) : stream_(d) { +CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); } @@ -287,6 +287,7 @@ void CommandEncoder::commit() { CHECK_CUDA_ERROR( cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); } + device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); // TODO smarter cache policy diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 4ebdae55c..8ac840cbb 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -93,6 +93,7 @@ class CommandEncoder { void insert_graph_dependencies(GraphNode node); void insert_graph_dependencies(std::vector nodes); + Device& device_; CudaStream stream_; cudaGraph_t graph_; Worker worker_;