diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h
index aae64fb5e..b8e33ca81 100644
--- a/mlx/backend/cpu/encoder.h
+++ b/mlx/backend/cpu/encoder.h
@@ -9,6 +9,9 @@

 namespace mlx::core::cpu {

+// Number of dispatches per scheduler task
+constexpr int DISPATCHES_PER_TASK = 10;
+
 struct CommandEncoder {
   CommandEncoder(Stream stream) : stream_(stream) {}

@@ -39,13 +42,24 @@ struct CommandEncoder {
   template <class F, class... Args>
   void dispatch(F&& f, Args&&... args) {
+    num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
     auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
-    scheduler::enqueue(stream_, std::move(task));
+    if (num_ops_ == 0) {
+      scheduler::notify_new_task(stream_);
+      auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
+        task();
+        scheduler::notify_task_completion(s);
+      };
+      scheduler::enqueue(stream_, std::move(task_wrap));
+    } else {
+      scheduler::enqueue(stream_, std::move(task));
+    }
   }

  private:
   Stream stream_;
   std::vector<array> temporaries_;
+  int num_ops_{0};
 };

 CommandEncoder& get_command_encoder(Stream stream);
diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp
index 04811e737..b23c8d561 100644
--- a/mlx/backend/cpu/eval.cpp
+++ b/mlx/backend/cpu/eval.cpp
@@ -33,12 +33,8 @@ void eval(array& arr) {
       buffers.erase(it);
     }
     auto& encoder = cpu::get_command_encoder(s);
-    scheduler::notify_new_task(s);
-    encoder.dispatch([s,
-                      buffers = std::move(buffers),
-                      temps = std::move(encoder.temporaries())]() {
-      scheduler::notify_task_completion(s);
-    });
+    encoder.dispatch([buffers = std::move(buffers),
+                      temps = std::move(encoder.temporaries())]() {});
 }

 } // namespace mlx::core::cpu
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index e28989a5c..930e570e2 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -19,9 +19,6 @@ namespace mlx::core::metal {

 namespace {

-// TODO nicer way to set this or possibly expose as an environment variable
-constexpr int MAX_BUFFERS_PER_QUEUE = 12;
-
 constexpr const char* default_mtllib_path = METAL_PATH;

 auto get_metal_version() {
@@ -256,7 +253,7 @@ Device::~Device() {

 void Device::new_queue(int index) {
   auto thread_pool = metal::new_scoped_memory_pool();
-  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
+  auto q = device_->newCommandQueue();
   debug_set_stream_queue_label(q, index);
   if (!q) {
     throw std::runtime_error(
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index ba13b4b59..a9a1bc4f6 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -75,11 +75,7 @@ void finalize(Stream s) {
   auto& d = metal::device(s.device);
   auto cb = d.get_command_buffer(s.index);
   d.end_encoding(s.index);
-  scheduler::notify_new_task(s);
-  cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) {
-    scheduler::notify_task_completion(s);
-    check_error(cbuf);
-  });
+  cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
   d.commit_command_buffer(s.index);
   d.get_command_buffer(s.index);
 }
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index f01082418..958899bec 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -21,7 +21,7 @@

 namespace mlx::core {

-static constexpr int MAX_ACTIVE_TASKS = 100;
+static constexpr int MAX_ACTIVE_TASKS = 10;

 /* This class is only meant to be used in eval
  * for synchronizing with the main thread. */
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 37e31f80b..510402b06 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -174,6 +174,17 @@ class TestEval(mlx_tests.MLXTestCase):
         post = mx.metal.get_peak_memory()
         self.assertEqual(pre, post)

+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_multistream_deadlock(self):
+        s1 = mx.default_stream(mx.gpu)
+        s2 = mx.new_stream(mx.gpu)
+
+        x = mx.array(1.0)
+        x = mx.abs(x, stream=s1)
+        for _ in range(1000):
+            x = mx.abs(x, stream=s2)
+        mx.eval(x)
+

 if __name__ == "__main__":
     unittest.main()
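The core of the change is the counting pattern in `CommandEncoder::dispatch`: only every `DISPATCHES_PER_TASK`-th dispatch is reported to the scheduler as an active task, so `MAX_ACTIVE_TASKS` (now 10) bounds outstanding work in units of 10 dispatches, i.e. at most roughly 100 dispatches in flight, while paying the notify/completion cost a tenth as often. The standalone C++20 sketch below is not MLX code: `ThrottledEncoder`, `MAX_IN_FLIGHT`, and the single-worker queue are invented stand-ins for MLX's scheduler and its `notify_new_task` / `notify_task_completion` pair, kept only to show the throttling logic in isolation.

```cpp
// Sketch of the "count every Nth dispatch as one task" pattern (C++20).
// All names here are illustrative, not MLX APIs: a counting semaphore
// plays the role of notify_new_task / notify_task_completion, and one
// worker thread plays the role of a CPU stream.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <semaphore>
#include <thread>

constexpr int DISPATCHES_PER_TASK = 10; // as in encoder.h above
constexpr int MAX_IN_FLIGHT = 10;       // stand-in for MAX_ACTIVE_TASKS

class ThrottledEncoder {
 public:
  ThrottledEncoder() : worker_([this] { run(); }) {}

  ~ThrottledEncoder() {
    enqueue(nullptr); // sentinel tells the worker to exit after draining
    worker_.join();
  }

  void dispatch(std::function<void()> f) {
    num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
    if (num_ops_ == 0) {
      // Every Nth dispatch is counted: block the producer if too much
      // work is outstanding, and free the slot once the work has run.
      slots_.acquire();
      enqueue([this, f = std::move(f)] {
        f();
        slots_.release();
      });
    } else {
      enqueue(std::move(f)); // uncounted dispatch: never blocks
    }
  }

 private:
  void enqueue(std::function<void()> f) {
    {
      std::lock_guard<std::mutex> lk(m_);
      q_.push(std::move(f));
    }
    cv_.notify_one();
  }

  void run() {
    for (;;) {
      std::unique_lock<std::mutex> lk(m_);
      cv_.wait(lk, [this] { return !q_.empty(); });
      auto f = std::move(q_.front());
      q_.pop();
      lk.unlock();
      if (!f) break; // hit the sentinel
      f();
    }
  }

  int num_ops_{0};
  std::counting_semaphore<MAX_IN_FLIGHT> slots_{MAX_IN_FLIGHT};
  std::mutex m_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> q_;
  std::thread worker_;
};

int main() {
  int sum = 0;
  {
    ThrottledEncoder enc;
    // At most MAX_IN_FLIGHT * DISPATCHES_PER_TASK = 100 dispatches are
    // ever outstanding, but only every 10th pays for semaphore traffic.
    for (int i = 0; i < 1000; ++i) {
      enc.dispatch([&sum, i] { sum += i; });
    }
  } // destructor drains the queue: a crude stand-in for synchronize()
  std::cout << "sum = " << sum << "\n"; // 499500
}
```

Because only counted dispatches can block the producer, the cap no longer lives inside the Metal command queue itself, which is presumably why `MAX_BUFFERS_PER_QUEUE` can be dropped: the scenario in `test_multistream_deadlock`, where one stream queues a long backlog while another waits on its output, no longer pins a thread against a fixed per-queue buffer limit.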