[CUDA] Set current device before cudaGraphLaunch (#2351)

This commit is contained in:
Cheng 2025-07-10 23:24:02 +09:00 committed by GitHub
parent 8c7bc30ce4
commit 8fb3e7a26c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 9 deletions

View File

@ -57,6 +57,14 @@ void Device::make_current() {
}
}
// Returns the command encoder registered for stream `s`, creating it on first
// use. Encoders are stored in `encoders_`, keyed by `s.index`.
CommandEncoder& Device::get_command_encoder(Stream s) {
  // try_emplace is already a find-or-insert: it constructs a new
  // CommandEncoder from *this only when no entry exists for s.index, and
  // otherwise leaves the existing entry untouched. A separate find() beforehand
  // would just repeat the lookup.
  return encoders_.try_emplace(s.index, *this).first->second;
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR(
@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
}
// Looks up the command encoder for stream `s` in `encoders_` (keyed by
// `s.index`) and returns it; an encoder bound to this device is created if the
// stream has none yet.
CommandEncoder& Device::get_command_encoder(Stream s) {
  // A single try_emplace call covers both the hit and the miss case: the
  // mapped CommandEncoder is only constructed when the key is absent, so no
  // preliminary find() is needed.
  auto [entry, inserted] = encoders_.try_emplace(s.index, *this);
  (void)inserted; // Callers get the same reference whether or not it is new.
  return entry->second;
}
// Constructs a command encoder for device `d`: initializes stream_ from `d`
// and creates graph_ with cudaGraphCreate, aborting via CHECK_CUDA_ERROR if
// that call fails.
// NOTE(review): two signature lines follow because this text is a rendered
// diff whose +/- markers were stripped. Presumably the first line is the
// removed version and the second, which additionally binds device_ to `d`,
// is its replacement -- confirm against the repository before relying on it.
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
@ -287,6 +287,7 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// TODO smarter cache policy

View File

@ -93,6 +93,7 @@ class CommandEncoder {
void insert_graph_dependencies(GraphNode node);
void insert_graph_dependencies(std::vector<GraphNode> nodes);
Device& device_;
CudaStream stream_;
cudaGraph_t graph_;
Worker worker_;