Read arrays from files faster (#1330)

* read faster

* faster write as well

* set default permission for linux

* comment
This commit is contained in:
Awni Hannun 2024-08-14 20:09:56 -07:00 committed by GitHub
parent 99bb7d3a58
commit d0630ffe8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 22 deletions

View File

@ -2,6 +2,7 @@
#include <dlfcn.h> #include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include <fstream>
#include <list> #include <list>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"

View File

@ -33,7 +33,7 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0); assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
reader_->seek(offset_, std::ios_base::beg); reader_->seek(offset_);
reader_->read(out.data<char>(), out.nbytes()); reader_->read(out.data<char>(), out.nbytes());
if (swap_endianness_) { if (swap_endianness_) {

View File

@ -2,6 +2,7 @@
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <fstream>
#include <numeric> #include <numeric>
#include "mlx/io/gguf.h" #include "mlx/io/gguf.h"

View File

@ -2,9 +2,11 @@
#pragma once #pragma once
#include <fstream> #include <fcntl.h>
#include <istream> #include <sys/stat.h>
#include <unistd.h>
#include <memory> #include <memory>
#include <sstream>
namespace mlx::core { namespace mlx::core {
@ -20,6 +22,7 @@ class Reader {
std::ios_base::seekdir way = std::ios_base::beg) = 0; std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void read(char* data, size_t n) = 0; virtual void read(char* data, size_t n) = 0;
virtual std::string label() const = 0; virtual std::string label() const = 0;
virtual ~Reader() = default;
}; };
class Writer { class Writer {
@ -32,35 +35,50 @@ class Writer {
std::ios_base::seekdir way = std::ios_base::beg) = 0; std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void write(const char* data, size_t n) = 0; virtual void write(const char* data, size_t n) = 0;
virtual std::string label() const = 0; virtual std::string label() const = 0;
virtual ~Writer() = default;
}; };
class FileReader : public Reader { class FileReader : public Reader {
public: public:
explicit FileReader(std::ifstream is)
: is_(std::move(is)), label_("stream") {}
explicit FileReader(std::string file_path) explicit FileReader(std::string file_path)
: is_(std::ifstream(file_path, std::ios::binary)), : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
label_(std::move(file_path)) {}
~FileReader() override {
close(fd_);
}
bool is_open() const override { bool is_open() const override {
return is_.is_open(); return fd_ > 0;
} }
bool good() const override { bool good() const override {
return is_.good(); return is_open();
} }
size_t tell() override { size_t tell() override {
return is_.tellg(); return lseek(fd_, 0, SEEK_CUR);
} }
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override { override {
is_.seekg(off, way); if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
} }
void read(char* data, size_t n) override { void read(char* data, size_t n) override {
is_.read(data, n); while (n != 0) {
auto m = ::read(fd_, data, n);
if (m <= 0) {
std::ostringstream msg;
msg << "[read] Unable to read " << n << " bytes from file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
} }
std::string label() const override { std::string label() const override {
@ -68,37 +86,52 @@ class FileReader : public Reader {
} }
private: private:
std::ifstream is_; int fd_;
std::string label_; std::string label_;
}; };
class FileWriter : public Writer { class FileWriter : public Writer {
public: public:
explicit FileWriter(std::ofstream os)
: os_(std::move(os)), label_("stream") {}
explicit FileWriter(std::string file_path) explicit FileWriter(std::string file_path)
: os_(std::ofstream(file_path, std::ios::binary)), : fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
label_(std::move(file_path)) {} label_(std::move(file_path)) {}
~FileWriter() override {
close(fd_);
}
bool is_open() const override { bool is_open() const override {
return os_.is_open(); return fd_ >= 0;
} }
bool good() const override { bool good() const override {
return os_.good(); return is_open();
} }
size_t tell() override { size_t tell() override {
return os_.tellp(); return lseek(fd_, 0, SEEK_CUR);
} }
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override { override {
os_.seekp(off, way); if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
} }
void write(const char* data, size_t n) override { void write(const char* data, size_t n) override {
os_.write(data, n); while (n != 0) {
auto m = ::write(fd_, data, n);
if (m <= 0) {
std::ostringstream msg;
msg << "[write] Unable to write " << n << " bytes to file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
} }
std::string label() const override { std::string label() const override {
@ -106,7 +139,7 @@ class FileWriter : public Writer {
} }
private: private:
std::ofstream os_; int fd_;
std::string label_; std::string label_;
}; };