From fcb65a38972bb02ba49cbb0363e50188eee3c2f7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 28 Aug 2024 11:49:07 -0700 Subject: [PATCH] Even Faster I/O (#1369) * try multithreading for faster IO * smaller batch size * Account for pread returning less than size * nit --------- Co-authored-by: Angelos Katharopoulos --- mlx/backend/common/load.cpp | 3 +- mlx/io/load.cpp | 46 +++++++++++++++++++- mlx/io/load.h | 42 ++++++++---------- mlx/io/safetensors.cpp | 2 +- mlx/io/threadpool.h | 86 +++++++++++++++++++++++++++++++++++++ python/src/load.cpp | 5 +++ 6 files changed, 157 insertions(+), 27 deletions(-) create mode 100644 mlx/io/threadpool.h diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index 98278b3dc..c4130233c 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -33,8 +33,7 @@ void Load::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 0); out.set_data(allocator::malloc_or_wait(out.nbytes())); - reader_->seek(offset_); - reader_->read(out.data(), out.nbytes()); + reader_->read(out.data(), out.nbytes(), offset_); if (swap_endianness_) { switch (out.itemsize()) { diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index d9403e4b2..a02864f4d 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -298,7 +298,51 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { /** Load array from file in .npy format */ array load(std::string file, StreamOrDevice s) { - return load(std::make_shared(std::move(file)), s); + return load(std::make_shared(std::move(file), 4), s); } +namespace io { + +void ParallelFileReader::read(char* data, size_t n) { + while (n != 0) { + auto m = ::read(fd_, data, std::min(n, static_cast(INT32_MAX))); + if (m <= 0) { + std::ostringstream msg; + msg << "[read] Unable to read " << n << " bytes from file."; + throw std::runtime_error(msg.str()); + } + data += m; + n -= m; + } +} + +void ParallelFileReader::read(char* data, size_t n, size_t offset) { + auto readfn = [fd = fd_](size_t offset, size_t size, char* buffer) -> bool { + while (size != 0) { + auto m = pread(fd, buffer, size, offset); + if (m <= 0) { + return false; + } + buffer += m; + size -= m; + } + return true; + }; + std::vector> futs; + while (n != 0) { + size_t m = std::min(batch_size_, n); + futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); + data += m; + n -= m; + offset += m; + } + for (auto& f : futs) { + if (!f.get()) { + throw std::runtime_error("[read] Unable to read from file."); + } + } +} + +} // namespace io + } // namespace mlx::core diff --git a/mlx/io/load.h b/mlx/io/load.h index ec9026c4f..402e62f0e 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -8,6 +8,8 @@ #include #include +#include "mlx/io/threadpool.h" + namespace mlx::core { namespace io { @@ -21,6 +23,7 @@ class Reader { int64_t off, std::ios_base::seekdir way = std::ios_base::beg) = 0; virtual void read(char* data, size_t n) = 0; + virtual void read(char* data, size_t n, size_t offset) = 0; virtual std::string label() const = 0; virtual ~Reader() = default; }; @@ -38,12 +41,14 @@ class Writer { virtual ~Writer() = default; }; -class FileReader : public Reader { +class ParallelFileReader : public Reader { public: - explicit FileReader(std::string file_path) - : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {} + explicit ParallelFileReader(std::string file_path, int num_threads) + : fd_(open(file_path.c_str(), O_RDONLY)), + label_(std::move(file_path)), + thread_pool_(ThreadPool(num_threads)) {} - ~FileReader() override { + ~ParallelFileReader() override { close(fd_); } @@ -59,35 +64,26 @@ class FileReader : public Reader { return lseek(fd_, 0, SEEK_CUR); } - void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) - override { - if (way == std::ios_base::beg) { - lseek(fd_, off, 0); - } else { - lseek(fd_, off, SEEK_CUR); - } + void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override { + throw std::runtime_error("[ParallelFileReader::seek] Not allowed"); } - void read(char* data, size_t n) override { - while (n != 0) { - auto m = ::read(fd_, data, std::min(n, static_cast(INT32_MAX))); - if (m <= 0) { - std::ostringstream msg; - msg << "[read] Unable to read " << n << " bytes from file."; - throw std::runtime_error(msg.str()); - } - data += m; - n -= m; - } - } + // Warning: do not use this function from multiple threads as + // it advances the file descriptor + void read(char* data, size_t n) override; + + void read(char* data, size_t n, size_t offset) override; std::string label() const override { return "file " + label_; } private: + // 4MB + static constexpr size_t batch_size_ = (1 << 22); int fd_; std::string label_; + ThreadPool thread_pool_; }; class FileWriter : public Writer { diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index 76d8151b9..4cca8e391 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors( } SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { - return load_safetensors(std::make_shared(file), s); + return load_safetensors(std::make_shared(file, 4), s); } void save_safetensors( diff --git a/mlx/io/threadpool.h b/mlx/io/threadpool.h new file mode 100644 index 000000000..02cc0bb60 --- /dev/null +++ b/mlx/io/threadpool.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue> tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future> { + using return_type = typename std::result_of_t; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error( + "[ThreadPool::enqueue] Not allowed on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) + worker.join(); +} diff --git a/python/src/load.cpp b/python/src/load.cpp index efad3d97d..86e92dc48 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -146,6 +146,11 @@ class PyFileReader : public io::Reader { } } + void read(char* data, size_t n, size_t offset) override { + seek(offset); + read(data, n); + } + std::string label() const override { return "python file object"; }