From b3f52c9fbe044f5886e4b04d9b5631b395baf0ab Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 14 Sep 2024 06:17:36 -0700 Subject: [PATCH] ensure io/comm streams are active before eval (#1412) --- mlx/backend/metal/primitives.cpp | 3 +-- mlx/distributed/mpi/mpi.cpp | 3 +++ mlx/primitives.h | 12 ++++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 8adeb75de..d9607efce 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -199,7 +199,6 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { } void Load::eval_gpu(const std::vector& inputs, array& out) { - static Stream io_stream = new_stream(Device::cpu); out.set_data(allocator::malloc_or_wait(out.nbytes())); auto read_task = [out = out, @@ -213,7 +212,7 @@ void Load::eval_gpu(const std::vector& inputs, array& out) { fut.wait(); out.event().signal(); }; - scheduler::enqueue(io_stream, std::move(signal_task)); + scheduler::enqueue(io_stream(), std::move(signal_task)); auto& d = metal::device(stream().device); d.end_encoding(stream().index); auto command_buffer = d.get_command_buffer(stream().index); diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 4504ebecb..c232c13e7 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -255,6 +255,9 @@ Group init(bool strict /* = false */) { } } + // Ensure the communication stream is alive before + // the graph is evaluated + detail::communication_stream(); return Group(global_group); } diff --git a/mlx/primitives.h b/mlx/primitives.h index 065666a34..5e5bda7c0 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1146,9 +1146,13 @@ class Load : public UnaryPrimitive { size_t offset, bool swap_endianness = false) : UnaryPrimitive(stream), - reader_(reader), + reader_(std::move(reader)), offset_(offset), - swap_endianness_(swap_endianness) {} + swap_endianness_(swap_endianness) { + if (stream.device == Device::gpu) { + io_stream(); + } + } void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1156,6 +1160,10 @@ class Load : public UnaryPrimitive { DEFINE_PRINT(Load) private: + Stream& io_stream() { + static Stream io_stream = new_stream(Device::cpu); + return io_stream; + }; void eval(const std::vector& inputs, array& out); std::shared_ptr reader_; size_t offset_;