Even Faster I/O (#1369)

* try multithreading for faster IO

* smaller batch size

* Account for pread returning less than size

* nit

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-08-28 11:49:07 -07:00
committed by GitHub
parent 4e22a1dffe
commit fcb65a3897
6 changed files with 157 additions and 27 deletions

View File

@@ -8,6 +8,8 @@
#include <memory>
#include <sstream>
#include "mlx/io/threadpool.h"
namespace mlx::core {
namespace io {
@@ -21,6 +23,7 @@ class Reader {
int64_t off,
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void read(char* data, size_t n) = 0;
virtual void read(char* data, size_t n, size_t offset) = 0;
virtual std::string label() const = 0;
virtual ~Reader() = default;
};
@@ -38,12 +41,14 @@ class Writer {
virtual ~Writer() = default;
};
class FileReader : public Reader {
class ParallelFileReader : public Reader {
public:
explicit FileReader(std::string file_path)
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
explicit ParallelFileReader(std::string file_path, int num_threads)
: fd_(open(file_path.c_str(), O_RDONLY)),
label_(std::move(file_path)),
thread_pool_(ThreadPool(num_threads)) {}
~FileReader() override {
~ParallelFileReader() override {
close(fd_);
}
@@ -59,35 +64,26 @@ class FileReader : public Reader {
return lseek(fd_, 0, SEEK_CUR);
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
}
void read(char* data, size_t n) override {
while (n != 0) {
auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
if (m <= 0) {
std::ostringstream msg;
msg << "[read] Unable to read " << n << " bytes from file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
}
// Warning: do not use this function from multiple threads as
// it advances the file descriptor
void read(char* data, size_t n) override;
void read(char* data, size_t n, size_t offset) override;
std::string label() const override {
return "file " + label_;
}
private:
// 4MB
static constexpr size_t batch_size_ = (1 << 22);
int fd_;
std::string label_;
ThreadPool thread_pool_;
};
class FileWriter : public Writer {