mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Add locks to FileStream
This commit is contained in:
		@@ -35,8 +35,12 @@ void Load::eval_io(
 | 
				
			|||||||
  array& out = outputs[0];
 | 
					  array& out = outputs[0];
 | 
				
			||||||
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 | 
					  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  reader_->seek(offset_, std::ios_base::beg);
 | 
					  {
 | 
				
			||||||
  reader_->read(out.data<char>(), out.nbytes());
 | 
					    std::lock_guard lock(*reader_);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    reader_->seek(offset_, std::ios_base::beg);
 | 
				
			||||||
 | 
					    reader_->read(out.data<char>(), out.nbytes());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (swap_endianness_) {
 | 
					  if (swap_endianness_) {
 | 
				
			||||||
    switch (out.itemsize()) {
 | 
					    switch (out.itemsize()) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,6 +20,8 @@ class Reader {
 | 
				
			|||||||
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
 | 
					      std::ios_base::seekdir way = std::ios_base::beg) = 0;
 | 
				
			||||||
  virtual void read(char* data, size_t n) = 0;
 | 
					  virtual void read(char* data, size_t n) = 0;
 | 
				
			||||||
  virtual std::string label() const = 0;
 | 
					  virtual std::string label() const = 0;
 | 
				
			||||||
 | 
					  virtual void lock() = 0;
 | 
				
			||||||
 | 
					  virtual void unlock() = 0;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Writer {
 | 
					class Writer {
 | 
				
			||||||
@@ -67,9 +69,18 @@ class FileReader : public Reader {
 | 
				
			|||||||
    return "file " + label_;
 | 
					    return "file " + label_;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void lock() override {
 | 
				
			||||||
 | 
					    is_mutex_.lock();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void unlock() override {
 | 
				
			||||||
 | 
					    is_mutex_.unlock();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::ifstream is_;
 | 
					  std::ifstream is_;
 | 
				
			||||||
  std::string label_;
 | 
					  std::string label_;
 | 
				
			||||||
 | 
					  std::mutex is_mutex_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FileWriter : public Writer {
 | 
					class FileWriter : public Writer {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -150,11 +150,21 @@ class PyFileReader : public io::Reader {
 | 
				
			|||||||
    return "python file object";
 | 
					    return "python file object";
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void lock() override {
 | 
				
			||||||
 | 
					    stream_mutex_.lock();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void unlock() override {
 | 
				
			||||||
 | 
					    stream_mutex_.unlock();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  nb::object pyistream_;
 | 
					  nb::object pyistream_;
 | 
				
			||||||
  nb::object readinto_func_;
 | 
					  nb::object readinto_func_;
 | 
				
			||||||
  nb::object seek_func_;
 | 
					  nb::object seek_func_;
 | 
				
			||||||
  nb::object tell_func_;
 | 
					  nb::object tell_func_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::mutex stream_mutex_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::pair<
 | 
					std::pair<
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user