MLX
Loading...
Searching...
No Matches
load.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <ios>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
10
11namespace mlx::core {
12
13namespace io {
14
/// Abstract, seekable source of raw bytes consumed by the loaders.
///
/// NOTE: the order of the virtual members below is deliberately unchanged,
/// since it fixes the vtable layout seen by every implementation.
class Reader {
 public:
  /// Whether the underlying resource was opened successfully.
  virtual bool is_open() const = 0;

  /// Whether the reader is currently usable.
  virtual bool good() const = 0;

  /// Current byte offset into the source. Not const, mirroring
  /// std::istream::tellg / std::ostream::tellp, which are non-const.
  virtual size_t tell() = 0;

  /// Moves the current position by `offset` bytes relative to `origin`
  /// (the beginning of the source unless stated otherwise).
  virtual void seek(
      int64_t offset,
      std::ios_base::seekdir origin = std::ios_base::beg) = 0;

  /// Copies the next `count` bytes into `buffer`. How short reads are
  /// reported is implementation-defined — confirm per implementation.
  virtual void read(char* buffer, size_t count) = 0;

  /// Human-readable description of the source (e.g. for error messages).
  virtual std::string label() const = 0;

  virtual ~Reader() = default;
};
27
/// Abstract, seekable sink for raw bytes produced by the savers.
///
/// NOTE: the order of the virtual members below is deliberately unchanged,
/// since it fixes the vtable layout seen by every implementation.
class Writer {
 public:
  /// Whether the underlying resource was opened successfully.
  virtual bool is_open() const = 0;

  /// Whether the writer is currently usable.
  virtual bool good() const = 0;

  /// Current byte offset into the sink. Not const, mirroring
  /// std::ostream::tellp, which is non-const.
  virtual size_t tell() = 0;

  /// Moves the current position by `offset` bytes relative to `origin`
  /// (the beginning of the sink unless stated otherwise).
  virtual void seek(
      int64_t offset,
      std::ios_base::seekdir origin = std::ios_base::beg) = 0;

  /// Emits `count` bytes taken from `buffer` at the current position.
  virtual void write(const char* buffer, size_t count) = 0;

  /// Human-readable description of the sink (e.g. for error messages).
  virtual std::string label() const = 0;

  virtual ~Writer() = default;
};
40
41class FileReader : public Reader {
42 public:
43 explicit FileReader(std::string file_path)
44 : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
45
46 ~FileReader() override {
47 close(fd_);
48 }
49
50 bool is_open() const override {
51 return fd_ > 0;
52 }
53
54 bool good() const override {
55 return is_open();
56 }
57
58 size_t tell() override {
59 return lseek(fd_, 0, SEEK_CUR);
60 }
61
62 void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
63 override {
64 if (way == std::ios_base::beg) {
65 lseek(fd_, off, 0);
66 } else {
67 lseek(fd_, off, SEEK_CUR);
68 }
69 }
70
71 void read(char* data, size_t n) override {
72 while (n != 0) {
73 auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
74 if (m <= 0) {
75 std::ostringstream msg;
76 msg << "[read] Unable to read " << n << " bytes from file.";
77 throw std::runtime_error(msg.str());
78 }
79 data += m;
80 n -= m;
81 }
82 }
83
84 std::string label() const override {
85 return "file " + label_;
86 }
87
88 private:
89 int fd_;
90 std::string label_;
91};
92
93class FileWriter : public Writer {
94 public:
95 explicit FileWriter(std::string file_path)
96 : fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
97 label_(std::move(file_path)) {}
98
99 ~FileWriter() override {
100 close(fd_);
101 }
102
103 bool is_open() const override {
104 return fd_ >= 0;
105 }
106
107 bool good() const override {
108 return is_open();
109 }
110
111 size_t tell() override {
112 return lseek(fd_, 0, SEEK_CUR);
113 }
114
115 void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
116 override {
117 if (way == std::ios_base::beg) {
118 lseek(fd_, off, 0);
119 } else {
120 lseek(fd_, off, SEEK_CUR);
121 }
122 }
123
124 void write(const char* data, size_t n) override {
125 while (n != 0) {
126 auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
127 if (m <= 0) {
128 std::ostringstream msg;
129 msg << "[write] Unable to write " << n << " bytes to file.";
130 throw std::runtime_error(msg.str());
131 }
132 data += m;
133 n -= m;
134 }
135 }
136
137 std::string label() const override {
138 return "file " + label_;
139 }
140
141 private:
142 int fd_;
143 std::string label_;
144};
145
146} // namespace io
147} // namespace mlx::core
Definition load.h:41
bool good() const override
Definition load.h:54
~FileReader() override
Definition load.h:46
void read(char *data, size_t n) override
Definition load.h:71
std::string label() const override
Definition load.h:84
FileReader(std::string file_path)
Definition load.h:43
bool is_open() const override
Definition load.h:50
size_t tell() override
Definition load.h:58
void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
Definition load.h:62
Definition load.h:93
FileWriter(std::string file_path)
Definition load.h:95
std::string label() const override
Definition load.h:137
void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
Definition load.h:115
bool good() const override
Definition load.h:107
size_t tell() override
Definition load.h:111
void write(const char *data, size_t n) override
Definition load.h:124
~FileWriter() override
Definition load.h:99
bool is_open() const override
Definition load.h:103
Definition load.h:15
virtual bool good() const =0
virtual size_t tell()=0
virtual bool is_open() const =0
virtual ~Reader()=default
virtual std::string label() const =0
virtual void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
virtual void read(char *data, size_t n)=0
Definition load.h:28
virtual bool good() const =0
virtual ~Writer()=default
virtual size_t tell()=0
virtual std::string label() const =0
virtual bool is_open() const =0
virtual void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
virtual void write(const char *data, size_t n)=0
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Definition allocator.h:7