mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock * comments
This commit is contained in:
parent
95e335db7b
commit
3c164fca8c
@ -9,6 +9,9 @@
|
|||||||
|
|
||||||
namespace mlx::core::cpu {
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
// Number of dispatches per scheduler task
|
||||||
|
constexpr int DISPATCHES_PER_TASK = 10;
|
||||||
|
|
||||||
struct CommandEncoder {
|
struct CommandEncoder {
|
||||||
CommandEncoder(Stream stream) : stream_(stream) {}
|
CommandEncoder(Stream stream) : stream_(stream) {}
|
||||||
|
|
||||||
@ -39,13 +42,24 @@ struct CommandEncoder {
|
|||||||
|
|
||||||
template <class F, class... Args>
|
template <class F, class... Args>
|
||||||
void dispatch(F&& f, Args&&... args) {
|
void dispatch(F&& f, Args&&... args) {
|
||||||
|
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
|
||||||
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||||
scheduler::enqueue(stream_, std::move(task));
|
if (num_ops_ == 0) {
|
||||||
|
scheduler::notify_new_task(stream_);
|
||||||
|
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
|
||||||
|
task();
|
||||||
|
scheduler::notify_task_completion(s);
|
||||||
|
};
|
||||||
|
scheduler::enqueue(stream_, std::move(task_wrap));
|
||||||
|
} else {
|
||||||
|
scheduler::enqueue(stream_, std::move(task));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Stream stream_;
|
Stream stream_;
|
||||||
std::vector<array> temporaries_;
|
std::vector<array> temporaries_;
|
||||||
|
int num_ops_{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
CommandEncoder& get_command_encoder(Stream stream);
|
CommandEncoder& get_command_encoder(Stream stream);
|
||||||
|
@ -33,12 +33,8 @@ void eval(array& arr) {
|
|||||||
buffers.erase(it);
|
buffers.erase(it);
|
||||||
}
|
}
|
||||||
auto& encoder = cpu::get_command_encoder(s);
|
auto& encoder = cpu::get_command_encoder(s);
|
||||||
scheduler::notify_new_task(s);
|
encoder.dispatch([buffers = std::move(buffers),
|
||||||
encoder.dispatch([s,
|
temps = std::move(encoder.temporaries())]() {});
|
||||||
buffers = std::move(buffers),
|
|
||||||
temps = std::move(encoder.temporaries())]() {
|
|
||||||
scheduler::notify_task_completion(s);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cpu
|
} // namespace mlx::core::cpu
|
||||||
|
@ -19,9 +19,6 @@ namespace mlx::core::metal {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// TODO nicer way to set this or possibly expose as an environment variable
|
|
||||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
|
||||||
|
|
||||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||||
|
|
||||||
auto get_metal_version() {
|
auto get_metal_version() {
|
||||||
@ -256,7 +253,7 @@ Device::~Device() {
|
|||||||
|
|
||||||
void Device::new_queue(int index) {
|
void Device::new_queue(int index) {
|
||||||
auto thread_pool = metal::new_scoped_memory_pool();
|
auto thread_pool = metal::new_scoped_memory_pool();
|
||||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
auto q = device_->newCommandQueue();
|
||||||
debug_set_stream_queue_label(q, index);
|
debug_set_stream_queue_label(q, index);
|
||||||
if (!q) {
|
if (!q) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
@ -75,11 +75,7 @@ void finalize(Stream s) {
|
|||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
d.end_encoding(s.index);
|
d.end_encoding(s.index);
|
||||||
scheduler::notify_new_task(s);
|
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
||||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) {
|
|
||||||
scheduler::notify_task_completion(s);
|
|
||||||
check_error(cbuf);
|
|
||||||
});
|
|
||||||
d.commit_command_buffer(s.index);
|
d.commit_command_buffer(s.index);
|
||||||
d.get_command_buffer(s.index);
|
d.get_command_buffer(s.index);
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
static constexpr int MAX_ACTIVE_TASKS = 100;
|
static constexpr int MAX_ACTIVE_TASKS = 10;
|
||||||
|
|
||||||
/* This class is only meant to be used in eval
|
/* This class is only meant to be used in eval
|
||||||
* for synchronizing with the main thread. */
|
* for synchronizing with the main thread. */
|
||||||
|
@ -174,6 +174,17 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
post = mx.metal.get_peak_memory()
|
post = mx.metal.get_peak_memory()
|
||||||
self.assertEqual(pre, post)
|
self.assertEqual(pre, post)
|
||||||
|
|
||||||
|
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
|
def test_multistream_deadlock(self):
|
||||||
|
s1 = mx.default_stream(mx.gpu)
|
||||||
|
s2 = mx.new_stream(mx.gpu)
|
||||||
|
|
||||||
|
x = mx.array(1.0)
|
||||||
|
x = mx.abs(x, stream=s1)
|
||||||
|
for _ in range(1000):
|
||||||
|
x = mx.abs(x, stream=s2)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user