From dba2bd1105a0e88ff0c9717d6c0382894afee435 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 29 Aug 2024 16:05:40 -0700 Subject: [PATCH] Even Even Faster IO (#1374) * even more faster io * make reader pool static * make python reader thread safe * one more optimization --- mlx/backend/metal/primitives.cpp | 15 +++++++++------ mlx/io/load.cpp | 26 ++++++++++++++++++++------ mlx/io/load.h | 13 ++++++------- mlx/io/safetensors.cpp | 2 +- mlx/io/threadpool.h | 27 ++++++++++++++++++++++----- python/src/load.cpp | 25 +++++++++++++++---------- 6 files changed, 73 insertions(+), 35 deletions(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6580dd7a7..8adeb75de 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -202,15 +202,18 @@ void Load::eval_gpu(const std::vector& inputs, array& out) { static Stream io_stream = new_stream(Device::cpu); out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto task = [out = out, - offset = offset_, - reader = reader_, - swap_endianness = swap_endianness_]() mutable { + auto read_task = [out = out, + offset = offset_, + reader = reader_, + swap_endianness = swap_endianness_]() mutable { load(out, offset, reader, swap_endianness); + }; + auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); + auto signal_task = [out = out, fut = std::move(fut)]() { + fut.wait(); out.event().signal(); }; - - scheduler::enqueue(io_stream, std::move(task)); + scheduler::enqueue(io_stream, std::move(signal_task)); auto& d = metal::device(stream().device); d.end_encoding(stream().index); auto command_buffer = d.get_command_buffer(stream().index); diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index a02864f4d..f2e6f85bd 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -298,11 +298,18 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { /** Load array from file in .npy format */ array load(std::string file, StreamOrDevice s) { - return load(std::make_shared(std::move(file), 4), s); + return load(std::make_shared(std::move(file)), s); } namespace io { +ThreadPool& thread_pool() { + static ThreadPool pool_{4}; + return pool_; +} + +ThreadPool ParallelFileReader::thread_pool_{4}; + void ParallelFileReader::read(char* data, size_t n) { while (n != 0) { auto m = ::read(fd_, data, std::min(n, static_cast(INT32_MAX))); @@ -330,11 +337,18 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) { }; std::vector> futs; while (n != 0) { - size_t m = std::min(batch_size_, n); - futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); - data += m; - n -= m; - offset += m; + if (n < batch_size_) { + if (!readfn(offset, n, data)) { + throw std::runtime_error("[read] Unable to read from file."); + } + break; + } else { + size_t m = batch_size_; + futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); + data += m; + n -= m; + offset += m; + } } for (auto& f : futs) { if (!f.get()) { diff --git a/mlx/io/load.h b/mlx/io/load.h index 402e62f0e..e75057f68 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -14,6 +14,8 @@ namespace mlx::core { namespace io { +ThreadPool& thread_pool(); + class Reader { public: virtual bool is_open() const = 0; @@ -43,10 +45,8 @@ class Writer { class ParallelFileReader : public Reader { public: - explicit ParallelFileReader(std::string file_path, int num_threads) - : fd_(open(file_path.c_str(), O_RDONLY)), - label_(std::move(file_path)), - thread_pool_(ThreadPool(num_threads)) {} + explicit ParallelFileReader(std::string file_path) + : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {} ~ParallelFileReader() override { close(fd_); @@ -79,11 +79,10 @@ class ParallelFileReader : public Reader { } private: - // 4MB - static constexpr size_t batch_size_ = (1 << 22); + static constexpr size_t batch_size_ = 1 << 25; + static ThreadPool thread_pool_; int fd_; std::string label_; - ThreadPool thread_pool_; }; class FileWriter : public Writer { diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index 4cca8e391..0a41f0826 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors( } SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { - return load_safetensors(std::make_shared(file, 4), s); + return load_safetensors(std::make_shared(file), s); } void save_safetensors( diff --git a/mlx/io/threadpool.h b/mlx/io/threadpool.h index 02cc0bb60..cdb3f7faf 100644 --- a/mlx/io/threadpool.h +++ b/mlx/io/threadpool.h @@ -1,3 +1,25 @@ +// This code was modified from https://github.com/progschj/ThreadPool +// The original License is copied below: +// +// Copyright (c) 2012 Jakob Progsch, Václav Zeman +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. +// +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: +// +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. +// +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. +// +// 3. This notice may not be removed or altered from any source +// distribution. #pragma once #include @@ -19,12 +41,8 @@ class ThreadPool { ~ThreadPool(); private: - // need to keep track of threads so we can join them std::vector workers; - // the task queue std::queue> tasks; - - // synchronization std::mutex queue_mutex; std::condition_variable condition; bool stop; @@ -63,7 +81,6 @@ auto ThreadPool::enqueue(F&& f, Args&&... args) { std::unique_lock lock(queue_mutex); - // don't allow enqueueing after stopping the pool if (stop) { throw std::runtime_error( "[ThreadPool::enqueue] Not allowed on stopped ThreadPool"); diff --git a/python/src/load.cpp b/python/src/load.cpp index 86e92dc48..84530bd46 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -138,6 +138,21 @@ class PyFileReader : public io::Reader { void read(char* data, size_t n) override { nb::gil_scoped_acquire gil; + _read(data, n); + } + + void read(char* data, size_t n, size_t offset) override { + nb::gil_scoped_acquire gil; + seek_func_(offset, (int)std::ios_base::beg); + _read(data, n); + } + + std::string label() const override { + return "python file object"; + } + + private: + void _read(char* data, size_t n) { auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE); nb::object bytes_read = readinto_func_(nb::handle(memview)); @@ -146,16 +161,6 @@ class PyFileReader : public io::Reader { } } - void read(char* data, size_t n, size_t offset) override { - seek(offset); - read(data, n); - } - - std::string label() const override { - return "python file object"; - } - - private: nb::object pyistream_; nb::object readinto_func_; nb::object seek_func_;