mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock * comments
This commit is contained in:
parent
95e335db7b
commit
3c164fca8c
@ -9,6 +9,9 @@
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
// Number of dispatches per scheduler task
|
||||
constexpr int DISPATCHES_PER_TASK = 10;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(Stream stream) : stream_(stream) {}
|
||||
|
||||
@ -39,13 +42,24 @@ struct CommandEncoder {
|
||||
|
||||
template <class F, class... Args>
|
||||
void dispatch(F&& f, Args&&... args) {
|
||||
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
|
||||
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||
scheduler::enqueue(stream_, std::move(task));
|
||||
if (num_ops_ == 0) {
|
||||
scheduler::notify_new_task(stream_);
|
||||
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
|
||||
task();
|
||||
scheduler::notify_task_completion(s);
|
||||
};
|
||||
scheduler::enqueue(stream_, std::move(task_wrap));
|
||||
} else {
|
||||
scheduler::enqueue(stream_, std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Stream stream_;
|
||||
std::vector<array> temporaries_;
|
||||
int num_ops_{0};
|
||||
};
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream stream);
|
||||
|
@ -33,12 +33,8 @@ void eval(array& arr) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
scheduler::notify_new_task(s);
|
||||
encoder.dispatch([s,
|
||||
buffers = std::move(buffers),
|
||||
temps = std::move(encoder.temporaries())]() {
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
encoder.dispatch([buffers = std::move(buffers),
|
||||
temps = std::move(encoder.temporaries())]() {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
||||
|
@ -19,9 +19,6 @@ namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto get_metal_version() {
|
||||
@ -256,7 +253,7 @@ Device::~Device() {
|
||||
|
||||
void Device::new_queue(int index) {
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
auto q = device_->newCommandQueue();
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
|
@ -75,11 +75,7 @@ void finalize(Stream s) {
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
@ -21,7 +21,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
static constexpr int MAX_ACTIVE_TASKS = 100;
|
||||
static constexpr int MAX_ACTIVE_TASKS = 10;
|
||||
|
||||
/* This class is only meant to be used in eval
|
||||
* for synchronizing with the main thread. */
|
||||
|
@ -174,6 +174,17 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
post = mx.metal.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_multistream_deadlock(self):
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
|
||||
x = mx.array(1.0)
|
||||
x = mx.abs(x, stream=s1)
|
||||
for _ in range(1000):
|
||||
x = mx.abs(x, stream=s2)
|
||||
mx.eval(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user