mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 03:36:40 +08:00
Read arrays from files faster (#1330)
* read faster
* faster write as well
* set default permission for linux
* comment
This commit is contained in:
parent
99bb7d3a58
commit
d0630ffe8c
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
|
@ -33,7 +33,7 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 0);
|
assert(inputs.size() == 0);
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
reader_->seek(offset_, std::ios_base::beg);
|
reader_->seek(offset_);
|
||||||
reader_->read(out.data<char>(), out.nbytes());
|
reader_->read(out.data<char>(), out.nbytes());
|
||||||
|
|
||||||
if (swap_endianness_) {
|
if (swap_endianness_) {
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "mlx/io/gguf.h"
|
#include "mlx/io/gguf.h"
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <fstream>
|
#include <fcntl.h>
|
||||||
#include <istream>
|
#include <sys/stat.h>
|
||||||
|
#include <unistd.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -20,6 +22,7 @@ class Reader {
|
|||||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||||
virtual void read(char* data, size_t n) = 0;
|
virtual void read(char* data, size_t n) = 0;
|
||||||
virtual std::string label() const = 0;
|
virtual std::string label() const = 0;
|
||||||
|
virtual ~Reader() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Writer {
|
class Writer {
|
||||||
@ -32,35 +35,50 @@ class Writer {
|
|||||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||||
virtual void write(const char* data, size_t n) = 0;
|
virtual void write(const char* data, size_t n) = 0;
|
||||||
virtual std::string label() const = 0;
|
virtual std::string label() const = 0;
|
||||||
|
virtual ~Writer() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FileReader : public Reader {
|
class FileReader : public Reader {
|
||||||
public:
|
public:
|
||||||
explicit FileReader(std::ifstream is)
|
|
||||||
: is_(std::move(is)), label_("stream") {}
|
|
||||||
explicit FileReader(std::string file_path)
|
explicit FileReader(std::string file_path)
|
||||||
: is_(std::ifstream(file_path, std::ios::binary)),
|
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
|
||||||
label_(std::move(file_path)) {}
|
|
||||||
|
~FileReader() override {
|
||||||
|
close(fd_);
|
||||||
|
}
|
||||||
|
|
||||||
bool is_open() const override {
|
bool is_open() const override {
|
||||||
return is_.is_open();
|
return fd_ > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool good() const override {
|
bool good() const override {
|
||||||
return is_.good();
|
return is_open();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t tell() override {
|
size_t tell() override {
|
||||||
return is_.tellg();
|
return lseek(fd_, 0, SEEK_CUR);
|
||||||
}
|
}
|
||||||
|
|
||||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
override {
|
override {
|
||||||
is_.seekg(off, way);
|
if (way == std::ios_base::beg) {
|
||||||
|
lseek(fd_, off, 0);
|
||||||
|
} else {
|
||||||
|
lseek(fd_, off, SEEK_CUR);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void read(char* data, size_t n) override {
|
void read(char* data, size_t n) override {
|
||||||
is_.read(data, n);
|
while (n != 0) {
|
||||||
|
auto m = ::read(fd_, data, n);
|
||||||
|
if (m <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[read] Unable to read " << n << " bytes from file.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
data += m;
|
||||||
|
n -= m;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string label() const override {
|
std::string label() const override {
|
||||||
@ -68,37 +86,52 @@ class FileReader : public Reader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::ifstream is_;
|
int fd_;
|
||||||
std::string label_;
|
std::string label_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FileWriter : public Writer {
|
class FileWriter : public Writer {
|
||||||
public:
|
public:
|
||||||
explicit FileWriter(std::ofstream os)
|
|
||||||
: os_(std::move(os)), label_("stream") {}
|
|
||||||
explicit FileWriter(std::string file_path)
|
explicit FileWriter(std::string file_path)
|
||||||
: os_(std::ofstream(file_path, std::ios::binary)),
|
: fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
|
||||||
label_(std::move(file_path)) {}
|
label_(std::move(file_path)) {}
|
||||||
|
|
||||||
|
~FileWriter() override {
|
||||||
|
close(fd_);
|
||||||
|
}
|
||||||
|
|
||||||
bool is_open() const override {
|
bool is_open() const override {
|
||||||
return os_.is_open();
|
return fd_ >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool good() const override {
|
bool good() const override {
|
||||||
return os_.good();
|
return is_open();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t tell() override {
|
size_t tell() override {
|
||||||
return os_.tellp();
|
return lseek(fd_, 0, SEEK_CUR);
|
||||||
}
|
}
|
||||||
|
|
||||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
override {
|
override {
|
||||||
os_.seekp(off, way);
|
if (way == std::ios_base::beg) {
|
||||||
|
lseek(fd_, off, 0);
|
||||||
|
} else {
|
||||||
|
lseek(fd_, off, SEEK_CUR);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void write(const char* data, size_t n) override {
|
void write(const char* data, size_t n) override {
|
||||||
os_.write(data, n);
|
while (n != 0) {
|
||||||
|
auto m = ::write(fd_, data, n);
|
||||||
|
if (m <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[write] Unable to write " << n << " bytes to file.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
data += m;
|
||||||
|
n -= m;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string label() const override {
|
std::string label() const override {
|
||||||
@ -106,7 +139,7 @@ class FileWriter : public Writer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::ofstream os_;
|
int fd_;
|
||||||
std::string label_;
|
std::string label_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user