mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add locks to FileStream
This commit is contained in:
parent
bae159738f
commit
8242d6d5ef
@ -35,8 +35,12 @@ void Load::eval_io(
|
|||||||
array& out = outputs[0];
|
array& out = outputs[0];
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
reader_->seek(offset_, std::ios_base::beg);
|
{
|
||||||
reader_->read(out.data<char>(), out.nbytes());
|
std::lock_guard lock(*reader_);
|
||||||
|
|
||||||
|
reader_->seek(offset_, std::ios_base::beg);
|
||||||
|
reader_->read(out.data<char>(), out.nbytes());
|
||||||
|
}
|
||||||
|
|
||||||
if (swap_endianness_) {
|
if (swap_endianness_) {
|
||||||
switch (out.itemsize()) {
|
switch (out.itemsize()) {
|
||||||
|
@ -20,6 +20,8 @@ class Reader {
|
|||||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||||
virtual void read(char* data, size_t n) = 0;
|
virtual void read(char* data, size_t n) = 0;
|
||||||
virtual std::string label() const = 0;
|
virtual std::string label() const = 0;
|
||||||
|
virtual void lock() = 0;
|
||||||
|
virtual void unlock() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Writer {
|
class Writer {
|
||||||
@ -67,9 +69,18 @@ class FileReader : public Reader {
|
|||||||
return "file " + label_;
|
return "file " + label_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void lock() override {
|
||||||
|
is_mutex_.lock();
|
||||||
|
}
|
||||||
|
|
||||||
|
void unlock() override {
|
||||||
|
is_mutex_.unlock();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::ifstream is_;
|
std::ifstream is_;
|
||||||
std::string label_;
|
std::string label_;
|
||||||
|
std::mutex is_mutex_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FileWriter : public Writer {
|
class FileWriter : public Writer {
|
||||||
|
@ -150,11 +150,21 @@ class PyFileReader : public io::Reader {
|
|||||||
return "python file object";
|
return "python file object";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void lock() override {
|
||||||
|
stream_mutex_.lock();
|
||||||
|
}
|
||||||
|
|
||||||
|
void unlock() override {
|
||||||
|
stream_mutex_.unlock();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
nb::object pyistream_;
|
nb::object pyistream_;
|
||||||
nb::object readinto_func_;
|
nb::object readinto_func_;
|
||||||
nb::object seek_func_;
|
nb::object seek_func_;
|
||||||
nb::object tell_func_;
|
nb::object tell_func_;
|
||||||
|
|
||||||
|
std::mutex stream_mutex_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::pair<
|
std::pair<
|
||||||
|
Loading…
Reference in New Issue
Block a user