mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
[CUDA] Set current device before cudaGraphLaunch (#2351)
This commit is contained in:
parent
8c7bc30ce4
commit
8fb3e7a26c
@ -57,6 +57,14 @@ void Device::make_current() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||||
|
auto it = encoders_.find(s.index);
|
||||||
|
if (it == encoders_.end()) {
|
||||||
|
it = encoders_.try_emplace(s.index, *this).first;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
||||||
auto it = encoders_.find(s.index);
|
|
||||||
if (it == encoders_.end()) {
|
|
||||||
it = encoders_.try_emplace(s.index, *this).first;
|
|
||||||
}
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -287,6 +287,7 @@ void CommandEncoder::commit() {
|
|||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
||||||
}
|
}
|
||||||
|
device_.make_current();
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
|
||||||
// TODO smarter cache policy
|
// TODO smarter cache policy
|
||||||
|
@ -93,6 +93,7 @@ class CommandEncoder {
|
|||||||
void insert_graph_dependencies(GraphNode node);
|
void insert_graph_dependencies(GraphNode node);
|
||||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||||
|
|
||||||
|
Device& device_;
|
||||||
CudaStream stream_;
|
CudaStream stream_;
|
||||||
cudaGraph_t graph_;
|
cudaGraph_t graph_;
|
||||||
Worker worker_;
|
Worker worker_;
|
||||||
|
Loading…
Reference in New Issue
Block a user