mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 12:31:13 +08:00
[CUDA] Set current device before cudaGraphLaunch (#2351)
This commit is contained in:
parent
8c7bc30ce4
commit
8fb3e7a26c
@ -57,6 +57,14 @@ void Device::make_current() {
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
auto it = encoders_.find(s.index);
|
||||
if (it == encoders_.end()) {
|
||||
it = encoders_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||
CHECK_CUDA_ERROR(
|
||||
@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
auto it = encoders_.find(s.index);
|
||||
if (it == encoders_.end()) {
|
||||
it = encoders_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
|
||||
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
@ -287,6 +287,7 @@ void CommandEncoder::commit() {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
||||
}
|
||||
device_.make_current();
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// TODO smarter cache policy
|
||||
|
@ -93,6 +93,7 @@ class CommandEncoder {
|
||||
void insert_graph_dependencies(GraphNode node);
|
||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
cudaGraph_t graph_;
|
||||
Worker worker_;
|
||||
|
Loading…
Reference in New Issue
Block a user