MLX
 
Loading...
Searching...
No Matches
scheduler.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <atomic>
6#include <future>
7#include <queue>
8#include <thread>
9#include <unordered_map>
10
13#include "mlx/device.h"
14#include "mlx/stream.h"
15
17
19 std::mutex mtx;
20 std::queue<std::function<void()>> q;
21 std::condition_variable cond;
22 bool stop;
23 std::thread thread;
24
26
28 {
29 std::lock_guard<std::mutex> lk(mtx);
30 stop = true;
31 }
32 cond.notify_one();
33 thread.join();
34 }
35
36 void thread_fn() {
37 while (true) {
38 std::function<void()> task;
39 {
40 std::unique_lock<std::mutex> lk(mtx);
41 cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
42 if (q.empty() && stop) {
43 return;
44 }
45 task = std::move(q.front());
46 q.pop();
47 }
48
49 task();
50 }
51 }
52
53 template <typename F>
54 void enqueue(F&& f) {
55 {
56 std::lock_guard<std::mutex> lk(mtx);
57 if (stop) {
58 throw std::runtime_error(
59 "Cannot enqueue work after stream is stopped.");
60 }
61 q.emplace(std::forward<F>(f));
62 }
63 cond.notify_one();
64 }
65};
66
67class Scheduler {
68 public:
69 Scheduler() : n_active_tasks_(0) {
70 if (metal::is_available()) {
71 default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
72 }
73 default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
74 }
75
76 // Not copyable or moveable
77 Scheduler(const Scheduler&) = delete;
78 Scheduler(Scheduler&&) = delete;
79 Scheduler& operator=(const Scheduler&) = delete;
81
83 streams_.emplace_back(streams_.size(), d);
84 if (d == Device::gpu) {
85 threads_.push_back(nullptr);
86 metal::new_stream(streams_.back());
87 } else {
88 threads_.push_back(new StreamThread{});
89 }
90 return streams_.back();
91 }
92
93 template <typename F>
94 void enqueue(const Stream& stream, F&& f);
95
97 return default_streams_.at(d.type);
98 }
99 Stream get_stream(int index) const {
100 return streams_.at(index);
101 }
102
103 void set_default_stream(const Stream& s) {
104 default_streams_.at(s.device.type) = s;
105 }
106
107 void notify_new_task(const Stream& stream) {
108 {
109 std::lock_guard<std::mutex> lk(mtx);
110 n_active_tasks_++;
111 }
112 completion_cv.notify_all();
113 }
114
115 void notify_task_completion(const Stream& stream) {
116 {
117 std::lock_guard<std::mutex> lk(mtx);
118 n_active_tasks_--;
119 }
120 completion_cv.notify_all();
121 }
122
123 int n_active_tasks() const {
124 return n_active_tasks_;
125 }
126
128 std::unique_lock<std::mutex> lk(mtx);
129 int n_tasks_old = n_active_tasks();
130 if (n_tasks_old > 1) {
131 completion_cv.wait(lk, [this, n_tasks_old] {
132 return this->n_active_tasks() != n_tasks_old;
133 });
134 }
135 }
136
138 for (auto s : streams_) {
139 synchronize(s);
140 }
141 for (auto t : threads_) {
142 if (t != nullptr) {
143 delete t;
144 }
145 }
146 }
147
148 private:
149 int n_active_tasks_;
150 std::vector<StreamThread*> threads_;
151 std::vector<Stream> streams_;
152 std::unordered_map<Device::DeviceType, Stream> default_streams_;
153 std::condition_variable completion_cv;
154 std::mutex mtx;
155};
156
157template <typename F>
158void Scheduler::enqueue(const Stream& stream, F&& f) {
159 threads_[stream.index]->enqueue(std::forward<F>(f));
160}
161
163
164template <typename F>
165void enqueue(const Stream& stream, F&& f) {
166 scheduler().enqueue(stream, std::forward<F>(f));
167}
168
169inline int n_active_tasks() {
170 return scheduler().n_active_tasks();
171}
172
173inline void notify_new_task(const Stream& stream) {
174 scheduler().notify_new_task(stream);
175}
176
177inline void notify_task_completion(const Stream& stream) {
179}
180
181inline void wait_for_one() {
183}
184
185} // namespace mlx::core::scheduler
Definition scheduler.h:67
void wait_for_one()
Definition scheduler.h:127
Scheduler & operator=(Scheduler &&)=delete
void enqueue(const Stream &stream, F &&f)
Definition scheduler.h:158
Stream new_stream(const Device &d)
Definition scheduler.h:82
Stream get_default_stream(const Device &d) const
Definition scheduler.h:96
Scheduler()
Definition scheduler.h:69
int n_active_tasks() const
Definition scheduler.h:123
Scheduler(const Scheduler &)=delete
~Scheduler()
Definition scheduler.h:137
void set_default_stream(const Stream &s)
Definition scheduler.h:103
Stream get_stream(int index) const
Definition scheduler.h:99
Scheduler & operator=(const Scheduler &)=delete
void notify_task_completion(const Stream &stream)
Definition scheduler.h:115
Scheduler(Scheduler &&)=delete
void notify_new_task(const Stream &stream)
Definition scheduler.h:107
void new_stream(Stream stream)
Definition scheduler.h:16
void notify_task_completion(const Stream &stream)
Definition scheduler.h:177
void notify_new_task(const Stream &stream)
Definition scheduler.h:173
void wait_for_one()
Definition scheduler.h:181
int n_active_tasks()
Definition scheduler.h:169
void enqueue(const Stream &stream, F &&f)
Definition scheduler.h:165
Scheduler & scheduler()
void synchronize()
Definition device.h:7
static constexpr DeviceType gpu
Definition device.h:14
static constexpr DeviceType cpu
Definition device.h:13
DeviceType type
Definition device.h:18
Definition stream.h:9
Device device
Definition stream.h:11
int index
Definition stream.h:10
Definition scheduler.h:18
void thread_fn()
Definition scheduler.h:36
StreamThread()
Definition scheduler.h:25
std::thread thread
Definition scheduler.h:23
bool stop
Definition scheduler.h:22
void enqueue(F &&f)
Definition scheduler.h:54
std::condition_variable cond
Definition scheduler.h:21
std::mutex mtx
Definition scheduler.h:19
~StreamThread()
Definition scheduler.h:27
std::queue< std::function< void()> > q
Definition scheduler.h:20