From d0630ffe8c95832454a54e55f83ddbdacc64daf8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 14 Aug 2024 20:09:56 -0700 Subject: [PATCH] Read arrays from files faster (#1330) * read faster * faster write as well * set default permission for linux * comment --- mlx/backend/common/compiled_cpu.cpp | 1 + mlx/backend/common/load.cpp | 2 +- mlx/io/gguf.cpp | 1 + mlx/io/load.h | 75 +++++++++++++++++++++-------- 4 files changed, 57 insertions(+), 22 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 40acb74b53..13f2233ad9 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "mlx/backend/common/compiled.h" diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index 91f4cee62d..98278b3dc7 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -33,7 +33,7 @@ void Load::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 0); out.set_data(allocator::malloc_or_wait(out.nbytes())); - reader_->seek(offset_, std::ios_base::beg); + reader_->seek(offset_); reader_->read(out.data(), out.nbytes()); if (swap_endianness_) { diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 934bcee82f..be1f382bae 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "mlx/io/gguf.h" diff --git a/mlx/io/load.h b/mlx/io/load.h index 637df1b199..565f748d8a 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -2,9 +2,11 @@ #pragma once -#include -#include +#include +#include +#include #include +#include namespace mlx::core { @@ -20,6 +22,7 @@ class Reader { std::ios_base::seekdir way = std::ios_base::beg) = 0; virtual void read(char* data, size_t n) = 0; virtual std::string label() const = 0; + virtual ~Reader() = default; }; class Writer { @@ -32,35 +35,50 @@ class Writer { std::ios_base::seekdir way = std::ios_base::beg) = 0; virtual void write(const char* data, size_t n) = 0; virtual std::string label() const = 0; + virtual ~Writer() = default; }; class FileReader : public Reader { public: - explicit FileReader(std::ifstream is) - : is_(std::move(is)), label_("stream") {} explicit FileReader(std::string file_path) - : is_(std::ifstream(file_path, std::ios::binary)), - label_(std::move(file_path)) {} + : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {} + + ~FileReader() override { + close(fd_); + } bool is_open() const override { - return is_.is_open(); + return fd_ > 0; } bool good() const override { - return is_.good(); + return is_open(); } size_t tell() override { - return is_.tellg(); + return lseek(fd_, 0, SEEK_CUR); } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { - is_.seekg(off, way); + if (way == std::ios_base::beg) { + lseek(fd_, off, 0); + } else { + lseek(fd_, off, SEEK_CUR); + } } void read(char* data, size_t n) override { - is_.read(data, n); + while (n != 0) { + auto m = ::read(fd_, data, n); + if (m <= 0) { + std::ostringstream msg; + msg << "[read] Unable to read " << n << " bytes from file."; + throw std::runtime_error(msg.str()); + } + data += m; + n -= m; + } } std::string label() const override { @@ -68,37 +86,52 @@ class FileReader : public Reader { } private: - std::ifstream is_; + int fd_; std::string label_; }; class FileWriter : public Writer { public: - explicit FileWriter(std::ofstream os) - : os_(std::move(os)), label_("stream") {} explicit FileWriter(std::string file_path) - : os_(std::ofstream(file_path, std::ios::binary)), + : fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)), label_(std::move(file_path)) {} + ~FileWriter() override { + close(fd_); + } + bool is_open() const override { - return os_.is_open(); + return fd_ >= 0; } bool good() const override { - return os_.good(); + return is_open(); } size_t tell() override { - return os_.tellp(); + return lseek(fd_, 0, SEEK_CUR); } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { - os_.seekp(off, way); + if (way == std::ios_base::beg) { + lseek(fd_, off, 0); + } else { + lseek(fd_, off, SEEK_CUR); + } } void write(const char* data, size_t n) override { - os_.write(data, n); + while (n != 0) { + auto m = ::write(fd_, data, n); + if (m <= 0) { + std::ostringstream msg; + msg << "[write] Unable to write " << n << " bytes to file."; + throw std::runtime_error(msg.str()); + } + data += m; + n -= m; + } } std::string label() const override { @@ -106,7 +139,7 @@ class FileWriter : public Writer { } private: - std::ofstream os_; + int fd_; std::string label_; };