mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock * comments
This commit is contained in:
@@ -9,6 +9,9 @@
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
// Number of dispatches per scheduler task
|
||||
constexpr int DISPATCHES_PER_TASK = 10;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(Stream stream) : stream_(stream) {}
|
||||
|
||||
@@ -39,13 +42,24 @@ struct CommandEncoder {
|
||||
|
||||
template <class F, class... Args>
|
||||
void dispatch(F&& f, Args&&... args) {
|
||||
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
|
||||
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||
scheduler::enqueue(stream_, std::move(task));
|
||||
if (num_ops_ == 0) {
|
||||
scheduler::notify_new_task(stream_);
|
||||
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
|
||||
task();
|
||||
scheduler::notify_task_completion(s);
|
||||
};
|
||||
scheduler::enqueue(stream_, std::move(task_wrap));
|
||||
} else {
|
||||
scheduler::enqueue(stream_, std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Stream stream_;
|
||||
std::vector<array> temporaries_;
|
||||
int num_ops_{0};
|
||||
};
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream stream);
|
||||
|
||||
Reference in New Issue
Block a user