mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 19:26:42 +08:00
Read arrays from files faster (#1330)
* read faster * faster write as well * set default permission for linux * comment
This commit is contained in:
parent
99bb7d3a58
commit
d0630ffe8c
@ -2,6 +2,7 @@
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
|
@ -33,7 +33,7 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->seek(offset_);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
|
||||
if (swap_endianness_) {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/io/gguf.h"
|
||||
|
@ -2,9 +2,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -20,6 +22,7 @@ class Reader {
|
||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||
virtual void read(char* data, size_t n) = 0;
|
||||
virtual std::string label() const = 0;
|
||||
virtual ~Reader() = default;
|
||||
};
|
||||
|
||||
class Writer {
|
||||
@ -32,35 +35,50 @@ class Writer {
|
||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||
virtual void write(const char* data, size_t n) = 0;
|
||||
virtual std::string label() const = 0;
|
||||
virtual ~Writer() = default;
|
||||
};
|
||||
|
||||
class FileReader : public Reader {
|
||||
public:
|
||||
explicit FileReader(std::ifstream is)
|
||||
: is_(std::move(is)), label_("stream") {}
|
||||
explicit FileReader(std::string file_path)
|
||||
: is_(std::ifstream(file_path, std::ios::binary)),
|
||||
label_(std::move(file_path)) {}
|
||||
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
|
||||
|
||||
~FileReader() override {
|
||||
close(fd_);
|
||||
}
|
||||
|
||||
bool is_open() const override {
|
||||
return is_.is_open();
|
||||
return fd_ > 0;
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return is_.good();
|
||||
return is_open();
|
||||
}
|
||||
|
||||
size_t tell() override {
|
||||
return is_.tellg();
|
||||
return lseek(fd_, 0, SEEK_CUR);
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
is_.seekg(off, way);
|
||||
if (way == std::ios_base::beg) {
|
||||
lseek(fd_, off, 0);
|
||||
} else {
|
||||
lseek(fd_, off, SEEK_CUR);
|
||||
}
|
||||
}
|
||||
|
||||
void read(char* data, size_t n) override {
|
||||
is_.read(data, n);
|
||||
while (n != 0) {
|
||||
auto m = ::read(fd_, data, n);
|
||||
if (m <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[read] Unable to read " << n << " bytes from file.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
data += m;
|
||||
n -= m;
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
@ -68,37 +86,52 @@ class FileReader : public Reader {
|
||||
}
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
int fd_;
|
||||
std::string label_;
|
||||
};
|
||||
|
||||
class FileWriter : public Writer {
|
||||
public:
|
||||
explicit FileWriter(std::ofstream os)
|
||||
: os_(std::move(os)), label_("stream") {}
|
||||
explicit FileWriter(std::string file_path)
|
||||
: os_(std::ofstream(file_path, std::ios::binary)),
|
||||
: fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
|
||||
label_(std::move(file_path)) {}
|
||||
|
||||
~FileWriter() override {
|
||||
close(fd_);
|
||||
}
|
||||
|
||||
bool is_open() const override {
|
||||
return os_.is_open();
|
||||
return fd_ >= 0;
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return os_.good();
|
||||
return is_open();
|
||||
}
|
||||
|
||||
size_t tell() override {
|
||||
return os_.tellp();
|
||||
return lseek(fd_, 0, SEEK_CUR);
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
os_.seekp(off, way);
|
||||
if (way == std::ios_base::beg) {
|
||||
lseek(fd_, off, 0);
|
||||
} else {
|
||||
lseek(fd_, off, SEEK_CUR);
|
||||
}
|
||||
}
|
||||
|
||||
void write(const char* data, size_t n) override {
|
||||
os_.write(data, n);
|
||||
while (n != 0) {
|
||||
auto m = ::write(fd_, data, n);
|
||||
if (m <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[write] Unable to write " << n << " bytes to file.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
data += m;
|
||||
n -= m;
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
@ -106,7 +139,7 @@ class FileWriter : public Writer {
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
int fd_;
|
||||
std::string label_;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user