MLX
 
Loading...
Searching...
No Matches
scheduler.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <atomic>
6#include <future>
7#include <queue>
8#include <thread>
9#include <unordered_map>
10
13#include "mlx/device.h"
14#include "mlx/stream.h"
15
17
19 std::mutex mtx;
20 std::queue<std::function<void()>> q;
21 std::condition_variable cond;
22 bool stop;
24 std::thread thread;
25
30
33 {
34 std::lock_guard<std::mutex> lk(mtx);
35 stop = true;
36 }
37 cond.notify_one();
38 thread.join();
39 }
40
41 void thread_fn() {
42 while (true) {
43 std::function<void()> task;
44 {
45 std::unique_lock<std::mutex> lk(mtx);
46 cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
47 if (q.empty() && stop) {
48 return;
49 }
50 task = std::move(q.front());
51 q.pop();
52 }
53
54 task();
55 }
56 }
57
58 template <typename F>
59 void enqueue(F&& f) {
60 {
61 std::lock_guard<std::mutex> lk(mtx);
62 if (stop) {
63 throw std::runtime_error(
64 "Cannot enqueue work after stream is stopped.");
65 }
66 q.emplace(std::forward<F>(f));
67 }
68 cond.notify_one();
69 }
70};
71
72class Scheduler {
73 public:
74 Scheduler() : n_active_tasks_(0) {
75 if (metal::is_available()) {
76 default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
77 }
78 default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
79 }
80
81 // Not copyable or moveable
82 Scheduler(const Scheduler&) = delete;
83 Scheduler(Scheduler&&) = delete;
84 Scheduler& operator=(const Scheduler&) = delete;
86
88 auto stream = Stream(streams_.size(), d);
89 streams_.push_back(new StreamThread{stream});
90 return stream;
91 }
92
93 template <typename F>
94 void enqueue(const Stream& stream, F&& f);
95
97 return default_streams_.at(d.type);
98 }
99 Stream get_stream(int index) const {
100 return streams_.at(index)->stream;
101 }
102
103 void set_default_stream(const Stream& s) {
104 default_streams_.at(s.device.type) = s;
105 }
106
107 void notify_new_task(const Stream& stream) {
108 {
109 std::lock_guard<std::mutex> lk(mtx);
110 n_active_tasks_++;
111 }
112 completion_cv.notify_all();
113 }
114
115 void notify_task_completion(const Stream& stream) {
116 {
117 std::lock_guard<std::mutex> lk(mtx);
118 n_active_tasks_--;
119 }
120 completion_cv.notify_all();
121 }
122
123 int n_active_tasks() const {
124 return n_active_tasks_;
125 }
126
128 std::unique_lock<std::mutex> lk(mtx);
129 int n_tasks_old = n_active_tasks();
130 if (n_tasks_old > 1) {
131 completion_cv.wait(lk, [this, n_tasks_old] {
132 return this->n_active_tasks() != n_tasks_old;
133 });
134 }
135 }
136
138 for (auto s : streams_) {
139 delete s;
140 }
141 }
142
143 private:
144 int n_active_tasks_;
145 std::vector<StreamThread*> streams_;
146 std::unordered_map<Device::DeviceType, Stream> default_streams_;
147 std::condition_variable completion_cv;
148 std::mutex mtx;
149};
150
151template <typename F>
152void Scheduler::enqueue(const Stream& stream, F&& f) {
153 streams_[stream.index]->enqueue(std::forward<F>(f));
154}
155
157
158template <typename F>
159void enqueue(const Stream& stream, F&& f) {
160 scheduler().enqueue(stream, std::forward<F>(f));
161}
162
163inline int n_active_tasks() {
164 return scheduler().n_active_tasks();
165}
166
167inline void notify_new_task(const Stream& stream) {
168 scheduler().notify_new_task(stream);
169}
170
171inline void notify_task_completion(const Stream& stream) {
173}
174
175inline void wait_for_one() {
177}
178
179} // namespace mlx::core::scheduler
Definition scheduler.h:72
void wait_for_one()
Definition scheduler.h:127
Scheduler & operator=(Scheduler &&)=delete
void enqueue(const Stream &stream, F &&f)
Definition scheduler.h:152
Stream new_stream(const Device &d)
Definition scheduler.h:87
Stream get_default_stream(const Device &d) const
Definition scheduler.h:96
Scheduler()
Definition scheduler.h:74
int n_active_tasks() const
Definition scheduler.h:123
Scheduler(const Scheduler &)=delete
~Scheduler()
Definition scheduler.h:137
void set_default_stream(const Stream &s)
Definition scheduler.h:103
Stream get_stream(int index) const
Definition scheduler.h:99
Scheduler & operator=(const Scheduler &)=delete
void notify_task_completion(const Stream &stream)
Definition scheduler.h:115
Scheduler(Scheduler &&)=delete
void notify_new_task(const Stream &stream)
Definition scheduler.h:107
void new_stream(Stream stream)
Definition scheduler.h:16
void notify_task_completion(const Stream &stream)
Definition scheduler.h:171
void notify_new_task(const Stream &stream)
Definition scheduler.h:167
void wait_for_one()
Definition scheduler.h:175
int n_active_tasks()
Definition scheduler.h:163
void enqueue(const Stream &stream, F &&f)
Definition scheduler.h:159
Scheduler & scheduler()
void synchronize()
Definition device.h:7
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
DeviceType type
Definition device.h:18
Definition stream.h:9
Device device
Definition stream.h:11
int index
Definition stream.h:10
Definition scheduler.h:18
void thread_fn()
Definition scheduler.h:41
std::thread thread
Definition scheduler.h:24
bool stop
Definition scheduler.h:22
void enqueue(F &&f)
Definition scheduler.h:59
std::condition_variable cond
Definition scheduler.h:21
std::mutex mtx
Definition scheduler.h:19
~StreamThread()
Definition scheduler.h:31
Stream stream
Definition scheduler.h:23
StreamThread(Stream stream)
Definition scheduler.h:26
std::queue< std::function< void()> > q
Definition scheduler.h:20