MLX
 
Loading...
Searching...
No Matches
load.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <ios>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>

#include <fcntl.h>
#ifdef _MSC_VER
#include <io.h>
#else
#include <sys/stat.h>
#include <unistd.h>
#endif

#include "mlx/threadpool.h"
17
// Files must be opened in binary mode so the C runtime never rewrites line
// endings (e.g. silently inserting '\r'). Windows is the only modern platform
// that still distinguishes binary from text mode; where the platform does not
// define O_BINARY at all, define it as 0 so OR-ing it into open() flags is a
// no-op.
#if !defined(O_BINARY)
#define O_BINARY 0
#endif
25
26namespace mlx::core {
27
28namespace io {
29
31
/// Abstract, file-like byte source.
///
/// Exposes a cursor (`tell` / `seek`) and two read flavours: one that
/// consumes bytes at the cursor and one that takes an explicit byte offset.
class Reader {
 public:
  /// Whether the underlying resource was opened successfully.
  virtual bool is_open() const = 0;

  /// Whether the reader is in a usable state.
  virtual bool good() const = 0;

  /// Current cursor position in bytes. Deliberately non-const, mirroring
  /// iostream where tellp/tellg are non-const.
  virtual size_t tell() = 0;

  /// Moves the cursor by `off` bytes relative to `way`.
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;

  /// Reads `n` bytes at the cursor into `data`.
  virtual void read(char* data, size_t n) = 0;

  /// Reads `n` bytes located at byte `offset` into `data` (presumably an
  /// absolute offset that does not depend on the cursor — confirm per impl).
  virtual void read(char* data, size_t n, size_t offset) = 0;

  /// Human-readable name of the source, for messages.
  virtual std::string label() const = 0;

  virtual ~Reader() = default;
};
45
/// Abstract, file-like byte sink with a movable cursor.
class Writer {
 public:
  /// Whether the underlying resource was opened successfully.
  virtual bool is_open() const = 0;

  /// Whether the writer is in a usable state.
  virtual bool good() const = 0;

  /// Current cursor position in bytes (non-const, like iostream's tellp).
  virtual size_t tell() = 0;

  /// Moves the cursor by `off` bytes relative to `way`.
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;

  /// Writes the `n` bytes starting at `data`.
  virtual void write(const char* data, size_t n) = 0;

  /// Human-readable name of the sink, for messages.
  virtual std::string label() const = 0;

  virtual ~Writer() = default;
};
58
59class ParallelFileReader : public Reader {
60 public:
61 explicit ParallelFileReader(std::string file_path)
62 : fd_(open(file_path.c_str(), O_RDONLY | O_BINARY)),
63 label_(std::move(file_path)) {}
64
66 close(fd_);
67 }
68
69 bool is_open() const override {
70 return fd_ > 0;
71 }
72
73 bool good() const override {
74 return is_open();
75 }
76
77 size_t tell() override {
78 return lseek(fd_, 0, SEEK_CUR);
79 }
80
81 // Warning: do not use this function from multiple threads as
82 // it advances the file descriptor
83 void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
84 override {
85 if (way == std::ios_base::beg) {
86 lseek(fd_, off, 0);
87 } else {
88 lseek(fd_, off, SEEK_CUR);
89 }
90 }
91
92 // Warning: do not use this function from multiple threads as
93 // it advances the file descriptor
94 void read(char* data, size_t n) override;
95
96 void read(char* data, size_t n, size_t offset) override;
97
98 std::string label() const override {
99 return "file " + label_;
100 }
101
102 private:
103 static constexpr size_t batch_size_ = 1 << 25;
104 static ThreadPool thread_pool_;
105 int fd_;
106 std::string label_;
107};
108
109class FileWriter : public Writer {
110 public:
111 explicit FileWriter(std::string file_path)
112 : fd_(open(
113 file_path.c_str(),
114 O_CREAT | O_WRONLY | O_TRUNC | O_BINARY,
115 0644)),
116 label_(std::move(file_path)) {}
117
118 FileWriter(const FileWriter&) = delete;
119 FileWriter& operator=(const FileWriter&) = delete;
121 std::swap(fd_, other.fd_);
122 }
123
124 ~FileWriter() override {
125 if (fd_ != 0) {
126 close(fd_);
127 }
128 }
129
130 bool is_open() const override {
131 return fd_ >= 0;
132 }
133
134 bool good() const override {
135 return is_open();
136 }
137
138 size_t tell() override {
139 return lseek(fd_, 0, SEEK_CUR);
140 }
141
142 void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
143 override {
144 if (way == std::ios_base::beg) {
145 lseek(fd_, off, 0);
146 } else {
147 lseek(fd_, off, SEEK_CUR);
148 }
149 }
150
151 void write(const char* data, size_t n) override {
152 while (n != 0) {
153 auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
154 if (m <= 0) {
155 std::ostringstream msg;
156 msg << "[write] Unable to write " << n << " bytes to file.";
157 throw std::runtime_error(msg.str());
158 }
159 data += m;
160 n -= m;
161 }
162 }
163
164 std::string label() const override {
165 return "file " + label_;
166 }
167
168 private:
169 int fd_{0};
170 std::string label_;
171};
172
173} // namespace io
174} // namespace mlx::core
Definition threadpool.h:35
FileWriter(FileWriter &&other)
Definition load.h:120
FileWriter(std::string file_path)
Definition load.h:111
std::string label() const override
Definition load.h:164
FileWriter & operator=(const FileWriter &)=delete
void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
Definition load.h:142
bool good() const override
Definition load.h:134
size_t tell() override
Definition load.h:138
void write(const char *data, size_t n) override
Definition load.h:151
~FileWriter() override
Definition load.h:124
bool is_open() const override
Definition load.h:130
FileWriter(const FileWriter &)=delete
void read(char *data, size_t n, size_t offset) override
size_t tell() override
Definition load.h:77
std::string label() const override
Definition load.h:98
void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
Definition load.h:83
bool is_open() const override
Definition load.h:69
void read(char *data, size_t n) override
ParallelFileReader(std::string file_path)
Definition load.h:61
bool good() const override
Definition load.h:73
~ParallelFileReader() override
Definition load.h:65
Definition load.h:32
virtual bool good() const =0
virtual size_t tell()=0
virtual void read(char *data, size_t n, size_t offset)=0
virtual bool is_open() const =0
virtual ~Reader()=default
virtual std::string label() const =0
virtual void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
virtual void read(char *data, size_t n)=0
Definition load.h:46
virtual bool good() const =0
virtual ~Writer()=default
virtual size_t tell()=0
virtual std::string label() const =0
virtual bool is_open() const =0
virtual void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
virtual void write(const char *data, size_t n)=0
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
#define O_BINARY
Definition load.h:23
Definition load.h:28
ThreadPool & thread_pool()
Definition allocator.h:7