From 8242d6d5ef24a6769feaa09b683f5646b582bf5f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 8 May 2024 23:19:27 -0700 Subject: [PATCH] Add locks to FileStream --- mlx/backend/io/primitives.cpp | 8 ++++++-- mlx/io/load.h | 11 +++++++++++ python/src/load.cpp | 10 ++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mlx/backend/io/primitives.cpp b/mlx/backend/io/primitives.cpp index b8569a7d0..31f12c038 100644 --- a/mlx/backend/io/primitives.cpp +++ b/mlx/backend/io/primitives.cpp @@ -35,8 +35,12 @@ void Load::eval_io( array& out = outputs[0]; out.set_data(allocator::malloc_or_wait(out.nbytes())); - reader_->seek(offset_, std::ios_base::beg); - reader_->read(out.data(), out.nbytes()); + { + std::lock_guard lock(*reader_); + + reader_->seek(offset_, std::ios_base::beg); + reader_->read(out.data(), out.nbytes()); + } if (swap_endianness_) { switch (out.itemsize()) { diff --git a/mlx/io/load.h b/mlx/io/load.h index 637df1b19..6b07f738e 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -20,6 +20,8 @@ class Reader { std::ios_base::seekdir way = std::ios_base::beg) = 0; virtual void read(char* data, size_t n) = 0; virtual std::string label() const = 0; + virtual void lock() = 0; + virtual void unlock() = 0; }; class Writer { @@ -67,9 +69,18 @@ class FileReader : public Reader { return "file " + label_; } + void lock() override { + is_mutex_.lock(); + } + + void unlock() override { + is_mutex_.unlock(); + } + private: std::ifstream is_; std::string label_; + std::mutex is_mutex_; }; class FileWriter : public Writer { diff --git a/python/src/load.cpp b/python/src/load.cpp index 93b4df40a..3cfdffd71 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -150,11 +150,21 @@ class PyFileReader : public io::Reader { return "python file object"; } + void lock() override { + stream_mutex_.lock(); + } + + void unlock() override { + stream_mutex_.unlock(); + } + private: nb::object pyistream_; nb::object readinto_func_; nb::object seek_func_; nb::object tell_func_; + + std::mutex stream_mutex_; }; std::pair<