Read arrays from files faster (#1330)

* read faster

* faster write as well

* set default permission for linux

* comment
This commit is contained in:
Awni Hannun 2024-08-14 20:09:56 -07:00 committed by GitHub
parent 99bb7d3a58
commit d0630ffe8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 22 deletions

View File

@ -2,6 +2,7 @@
#include <dlfcn.h>
#include <filesystem>
#include <fstream>
#include <list>
#include "mlx/backend/common/compiled.h"

View File

@ -33,7 +33,7 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
reader_->seek(offset_, std::ios_base::beg);
reader_->seek(offset_);
reader_->read(out.data<char>(), out.nbytes());
if (swap_endianness_) {

View File

@ -2,6 +2,7 @@
#include <cstdint>
#include <cstring>
#include <fstream>
#include <numeric>
#include "mlx/io/gguf.h"

View File

@ -2,9 +2,11 @@
#pragma once
#include <fstream>
#include <istream>
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <memory>
#include <sstream>
namespace mlx::core {
@ -20,6 +22,7 @@ class Reader {
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void read(char* data, size_t n) = 0;
virtual std::string label() const = 0;
virtual ~Reader() = default;
};
class Writer {
@ -32,35 +35,50 @@ class Writer {
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void write(const char* data, size_t n) = 0;
virtual std::string label() const = 0;
virtual ~Writer() = default;
};
class FileReader : public Reader {
public:
explicit FileReader(std::ifstream is)
: is_(std::move(is)), label_("stream") {}
explicit FileReader(std::string file_path)
: is_(std::ifstream(file_path, std::ios::binary)),
label_(std::move(file_path)) {}
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
~FileReader() override {
close(fd_);
}
bool is_open() const override {
return is_.is_open();
return fd_ > 0;
}
bool good() const override {
return is_.good();
return is_open();
}
size_t tell() override {
return is_.tellg();
return lseek(fd_, 0, SEEK_CUR);
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
is_.seekg(off, way);
if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
}
void read(char* data, size_t n) override {
is_.read(data, n);
while (n != 0) {
auto m = ::read(fd_, data, n);
if (m <= 0) {
std::ostringstream msg;
msg << "[read] Unable to read " << n << " bytes from file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
}
std::string label() const override {
@ -68,37 +86,52 @@ class FileReader : public Reader {
}
private:
std::ifstream is_;
int fd_;
std::string label_;
};
class FileWriter : public Writer {
public:
explicit FileWriter(std::ofstream os)
: os_(std::move(os)), label_("stream") {}
explicit FileWriter(std::string file_path)
: os_(std::ofstream(file_path, std::ios::binary)),
: fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
label_(std::move(file_path)) {}
~FileWriter() override {
close(fd_);
}
bool is_open() const override {
return os_.is_open();
return fd_ >= 0;
}
bool good() const override {
return os_.good();
return is_open();
}
size_t tell() override {
return os_.tellp();
return lseek(fd_, 0, SEEK_CUR);
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
os_.seekp(off, way);
if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
}
void write(const char* data, size_t n) override {
os_.write(data, n);
while (n != 0) {
auto m = ::write(fd_, data, n);
if (m <= 0) {
std::ostringstream msg;
msg << "[write] Unable to write " << n << " bytes to file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
}
std::string label() const override {
@ -106,7 +139,7 @@ class FileWriter : public Writer {
}
private:
std::ofstream os_;
int fd_;
std::string label_;
};