MLX
Loading...
Searching...
No Matches
load.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <ios>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

#include "mlx/io/threadpool.h"
12
13namespace mlx::core {
14
15namespace io {
16
18
/// Abstract interface for reading raw bytes from some source.
///
/// It offers a stream-like cursor API (tell / seek / read(data, n)) and a
/// read(data, n, offset) overload that takes an explicit byte offset.
class Reader {
 public:
  /// Whether the underlying source was opened successfully.
  virtual bool is_open() const = 0;

  /// Whether the reader is currently in a usable state.
  virtual bool good() const = 0;

  /// Current cursor position. Deliberately non-const, mirroring
  /// std::istream::tellg / std::ostream::tellp, which are non-const.
  virtual size_t tell() = 0;

  /// Moves the cursor by `off` relative to `way` (the start by default).
  /// Implementations may refuse to seek (e.g. by throwing).
  virtual void
  seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) = 0;

  /// Copies `n` bytes into `data`, presumably from the current cursor.
  virtual void read(char* data, size_t n) = 0;

  /// Copies `n` bytes into `data`, starting at byte `offset` of the source.
  virtual void read(char* data, size_t n, size_t offset) = 0;

  /// Short human-readable description of the source.
  virtual std::string label() const = 0;

  virtual ~Reader() = default;
};
32
/// Abstract interface for writing raw bytes to some sink, with a
/// stream-like cursor (tell / seek) controlling where writes land.
class Writer {
 public:
  /// Whether the underlying sink was opened successfully.
  virtual bool is_open() const = 0;

  /// Whether the writer is currently in a usable state.
  virtual bool good() const = 0;

  /// Current cursor position. Non-const, mirroring std::ostream::tellp.
  virtual size_t tell() = 0;

  /// Moves the cursor by `off` relative to `way` (the start by default).
  virtual void
  seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) = 0;

  /// Writes the `n` bytes pointed to by `data`.
  virtual void write(const char* data, size_t n) = 0;

  /// Short human-readable description of the sink.
  virtual std::string label() const = 0;

  virtual ~Writer() = default;
};
45
46class ParallelFileReader : public Reader {
47 public:
48 explicit ParallelFileReader(std::string file_path)
49 : fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
50
52 close(fd_);
53 }
54
55 bool is_open() const override {
56 return fd_ > 0;
57 }
58
59 bool good() const override {
60 return is_open();
61 }
62
63 size_t tell() override {
64 return lseek(fd_, 0, SEEK_CUR);
65 }
66
67 void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
68 throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
69 }
70
71 // Warning: do not use this function from multiple threads as
72 // it advances the file descriptor
73 void read(char* data, size_t n) override;
74
75 void read(char* data, size_t n, size_t offset) override;
76
77 std::string label() const override {
78 return "file " + label_;
79 }
80
81 private:
82 static constexpr size_t batch_size_ = 1 << 25;
83 static ThreadPool thread_pool_;
84 int fd_;
85 std::string label_;
86};
87
88class FileWriter : public Writer {
89 public:
90 explicit FileWriter(std::string file_path)
91 : fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
92 label_(std::move(file_path)) {}
93
94 ~FileWriter() override {
95 close(fd_);
96 }
97
98 bool is_open() const override {
99 return fd_ >= 0;
100 }
101
102 bool good() const override {
103 return is_open();
104 }
105
106 size_t tell() override {
107 return lseek(fd_, 0, SEEK_CUR);
108 }
109
110 void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
111 override {
112 if (way == std::ios_base::beg) {
113 lseek(fd_, off, 0);
114 } else {
115 lseek(fd_, off, SEEK_CUR);
116 }
117 }
118
119 void write(const char* data, size_t n) override {
120 while (n != 0) {
121 auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
122 if (m <= 0) {
123 std::ostringstream msg;
124 msg << "[write] Unable to write " << n << " bytes to file.";
125 throw std::runtime_error(msg.str());
126 }
127 data += m;
128 n -= m;
129 }
130 }
131
132 std::string label() const override {
133 return "file " + label_;
134 }
135
136 private:
137 int fd_;
138 std::string label_;
139};
140
141} // namespace io
142} // namespace mlx::core
Definition threadpool.h:35
Definition load.h:88
FileWriter(std::string file_path)
Definition load.h:90
std::string label() const override
Definition load.h:132
void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
Definition load.h:110
bool good() const override
Definition load.h:102
size_t tell() override
Definition load.h:106
void write(const char *data, size_t n) override
Definition load.h:119
~FileWriter() override
Definition load.h:94
bool is_open() const override
Definition load.h:98
void read(char *data, size_t n, size_t offset) override
size_t tell() override
Definition load.h:63
std::string label() const override
Definition load.h:77
bool is_open() const override
Definition load.h:55
void read(char *data, size_t n) override
void seek(int64_t, std::ios_base::seekdir=std::ios_base::beg) override
Definition load.h:67
ParallelFileReader(std::string file_path)
Definition load.h:48
bool good() const override
Definition load.h:59
~ParallelFileReader() override
Definition load.h:51
Definition load.h:19
virtual bool good() const =0
virtual size_t tell()=0
virtual void read(char *data, size_t n, size_t offset)=0
virtual bool is_open() const =0
virtual ~Reader()=default
virtual std::string label() const =0
virtual void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
virtual void read(char *data, size_t n)=0
Definition load.h:33
virtual bool good() const =0
virtual ~Writer()=default
virtual size_t tell()=0
virtual std::string label() const =0
virtual bool is_open() const =0
virtual void seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
virtual void write(const char *data, size_t n)=0
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
ThreadPool & thread_pool()
Definition allocator.h:7