Fix multistream GPU deadlock (#1969)

* fix multistream GPU deadlock

* comments
Awni Hannun 2025-03-20 07:19:47 -07:00 committed by GitHub
parent 95e335db7b
commit 3c164fca8c
6 changed files with 31 additions and 17 deletions


@@ -9,6 +9,9 @@
 namespace mlx::core::cpu {
 
+// Number of dispatches per scheduler task
+constexpr int DISPATCHES_PER_TASK = 10;
+
 struct CommandEncoder {
   CommandEncoder(Stream stream) : stream_(stream) {}
@@ -39,13 +42,24 @@ struct CommandEncoder {
   template <class F, class... Args>
   void dispatch(F&& f, Args&&... args) {
+    num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
     auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
-    scheduler::enqueue(stream_, std::move(task));
+    if (num_ops_ == 0) {
+      scheduler::notify_new_task(stream_);
+      auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
+        task();
+        scheduler::notify_task_completion(s);
+      };
+      scheduler::enqueue(stream_, std::move(task_wrap));
+    } else {
+      scheduler::enqueue(stream_, std::move(task));
+    }
   }
 
  private:
   Stream stream_;
   std::vector<array> temporaries_;
+  int num_ops_{0};
 };
 
 CommandEncoder& get_command_encoder(Stream stream);
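
Note: the encoder change above makes a CPU stream report to the scheduler only once every DISPATCHES_PER_TASK dispatches; one wrapped task stands in for the whole batch behind it. Below is a minimal standalone sketch of that batching pattern. ToyScheduler, run_all, and the counters are invented stand-ins for illustration, not MLX's actual scheduler.

#include <functional>
#include <iostream>
#include <queue>

constexpr int DISPATCHES_PER_TASK = 10; // same constant as in the diff

// Invented single-stream stand-in for MLX's scheduler.
struct ToyScheduler {
  int active_tasks = 0;                // what a cap like MAX_ACTIVE_TASKS sees
  std::queue<std::function<void()>> q; // pending work for one stream
  void notify_new_task() { ++active_tasks; }
  void notify_task_completion() { --active_tasks; }
  void enqueue(std::function<void()> f) { q.push(std::move(f)); }
  void run_all() {
    while (!q.empty()) {
      q.front()();
      q.pop();
    }
  }
};

int main() {
  ToyScheduler sched;
  int num_ops = 0;
  int tracked = 0;
  for (int i = 0; i < 25; ++i) {
    num_ops = (num_ops + 1) % DISPATCHES_PER_TASK;
    auto task = [] { /* run one op */ };
    if (num_ops == 0) {
      // Every 10th dispatch is wrapped: it counts against the scheduler's
      // cap and reports completion for the whole batch behind it.
      sched.notify_new_task();
      ++tracked;
      sched.enqueue([&sched, task] {
        task();
        sched.notify_task_completion();
      });
    } else {
      sched.enqueue(task);
    }
  }
  std::cout << "25 dispatches -> " << tracked << " tracked tasks\n"; // 2
  sched.run_all();
  std::cout << "active after draining: " << sched.active_tasks << "\n"; // 0
}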


@@ -33,12 +33,8 @@ void eval(array& arr) {
     buffers.erase(it);
   }
   auto& encoder = cpu::get_command_encoder(s);
-  scheduler::notify_new_task(s);
-  encoder.dispatch([s,
-                    buffers = std::move(buffers),
-                    temps = std::move(encoder.temporaries())]() {
-    scheduler::notify_task_completion(s);
-  });
+  encoder.dispatch([buffers = std::move(buffers),
+                    temps = std::move(encoder.temporaries())]() {});
 }
 
 } // namespace mlx::core::cpu
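
The eval change drops the per-eval notify_new_task/notify_task_completion pair (the encoder above now does that accounting), yet it still dispatches an empty closure capturing the buffers and temporaries. Presumably that is a keep-alive trick: the captures stay referenced until the stream's worker has run everything queued before them. A minimal sketch of the idea, with a plain vector standing in for a stream's queue (invented for illustration):

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

int main() {
  std::vector<std::function<void()>> stream; // stand-in for a stream's queue
  auto buf = std::make_shared<int>(42);      // stand-in for a shared buffer
  stream.push_back([buf] { std::printf("op reads %d\n", *buf); });
  // Park the remaining reference in a do-nothing task at the back of the
  // queue: the buffer cannot be freed before every earlier task has run.
  stream.push_back([keep = std::move(buf)] {});
  for (auto& task : stream) task();  // the worker drains in order
  stream.clear();                    // only now does the last reference die
}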


@@ -19,9 +19,6 @@ namespace mlx::core::metal {
 
 namespace {
 
-// TODO nicer way to set this or possibly expose as an environment variable
-constexpr int MAX_BUFFERS_PER_QUEUE = 12;
-
 constexpr const char* default_mtllib_path = METAL_PATH;
 
 auto get_metal_version() {
@@ -256,7 +253,7 @@ Device::~Device() {
 
 void Device::new_queue(int index) {
   auto thread_pool = metal::new_scoped_memory_pool();
-  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
+  auto q = device_->newCommandQueue();
   debug_set_stream_queue_label(q, index);
   if (!q) {
     throw std::runtime_error(
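
A Metal queue created with a maximum command buffer count hands out at most that many uncompleted buffers; requesting one more blocks the calling thread until a completion handler fires. That hard cap is plausibly one half of the reported deadlock: a blocked eval thread can no longer commit the very work a completion may depend on. A rough analogy with a counting semaphore (not Metal code; the numbers mirror the deleted constant):

#include <cstdio>
#include <semaphore>
#include <thread>

int main() {
  // 12 slots, mirroring the deleted MAX_BUFFERS_PER_QUEUE: each in-flight
  // command buffer would hold one until its completion handler fires.
  std::counting_semaphore<12> slots(12);
  for (int i = 0; i < 12; ++i) slots.acquire(); // 12 uncompleted buffers
  std::thread gpu([&] { slots.release(); });    // one buffer completes
  // Requesting buffer #13 blocks here. If that completion could itself
  // only happen after work this thread still has to commit, no slot is
  // ever released and the program hangs; removing the cap removes the wait.
  slots.acquire();
  std::printf("slot freed by a completion\n");
  gpu.join();
}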


@@ -75,11 +75,7 @@ void finalize(Stream s) {
   auto& d = metal::device(s.device);
   auto cb = d.get_command_buffer(s.index);
   d.end_encoding(s.index);
-  scheduler::notify_new_task(s);
-  cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) {
-    scheduler::notify_task_completion(s);
-    check_error(cbuf);
-  });
+  cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
   d.commit_command_buffer(s.index);
   d.get_command_buffer(s.index);
 }


@@ -21,7 +21,7 @@
 
 namespace mlx::core {
 
-static constexpr int MAX_ACTIVE_TASKS = 100;
+static constexpr int MAX_ACTIVE_TASKS = 10;
 
 /* This class is only meant to be used in eval
  * for synchronizing with the main thread. */
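
MAX_ACTIVE_TASKS caps how far eval can run ahead of the streams; with the CPU encoder now notifying once per ten dispatches, dropping the cap from 100 tracked tasks to 10 plausibly keeps roughly the same amount of buffered work as before. A toy version of such a cap with a condition variable, using invented names rather than the real class in scheduler.h (the blocking wait in notify_task_completion is a demo convenience, not something the diff implies):

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

constexpr int MAX_ACTIVE_TASKS = 10; // value after this commit

// Invented stand-in for the synchronization class mentioned above.
struct Throttle {
  std::mutex m;
  std::condition_variable cv;
  int active = 0;
  void notify_new_task() { // eval thread: block once the cap is hit
    std::unique_lock lk(m);
    cv.wait(lk, [&] { return active < MAX_ACTIVE_TASKS; });
    ++active;
    cv.notify_all();
  }
  void notify_task_completion() { // worker thread: retire one task
    std::unique_lock lk(m);
    cv.wait(lk, [&] { return active > 0; });
    --active;
    cv.notify_all();
  }
};

int main() {
  Throttle t;
  std::thread worker([&] {
    for (int i = 0; i < 100; ++i) t.notify_task_completion();
  });
  for (int i = 0; i < 100; ++i) t.notify_new_task(); // never >10 ahead
  worker.join();
  std::printf("outstanding after join: %d\n", t.active); // 0
}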


@@ -174,6 +174,17 @@ class TestEval(mlx_tests.MLXTestCase):
         post = mx.metal.get_peak_memory()
         self.assertEqual(pre, post)
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_multistream_deadlock(self):
+        s1 = mx.default_stream(mx.gpu)
+        s2 = mx.new_stream(mx.gpu)
+        x = mx.array(1.0)
+        x = mx.abs(x, stream=s1)
+        for _ in range(1000):
+            x = mx.abs(x, stream=s2)
+        mx.eval(x)
+
 
 if __name__ == "__main__":
     unittest.main()