From 11354d5bffe01701c23ea0f00886b820f3f2c3d8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 27 Sep 2024 13:32:14 -0700 Subject: [PATCH] Avoid io timeout for large arrays (#1442) --- mlx/backend/metal/primitives.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index d9607efce..31f2248d7 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -200,13 +200,19 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { void Load::eval_gpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto read_task = [out = out, offset = offset_, reader = reader_, swap_endianness = swap_endianness_]() mutable { load(out, offset, reader, swap_endianness); }; + + // Limit the size that the command buffer will wait on to avoid timing out + // on the event (<4 seconds). + if (out.nbytes() > (1 << 28)) { + read_task(); + return; + } auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); auto signal_task = [out = out, fut = std::move(fut)]() { fut.wait();