mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	NPY loading segfault bug (#34)
* Fixed GIL semantics in loading and saving from Python file streams
This commit is contained in:
		@@ -11,8 +11,6 @@
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include <iostream>
 | 
			
		||||
 | 
			
		||||
#include "mlx/load.h"
 | 
			
		||||
#include "mlx/ops.h"
 | 
			
		||||
#include "mlx/utils.h"
 | 
			
		||||
@@ -99,26 +97,54 @@ class PyFileReader : public io::Reader {
 | 
			
		||||
        seek_func_(file.attr("seek")),
 | 
			
		||||
        tell_func_(file.attr("tell")) {}
 | 
			
		||||
 | 
			
		||||
  ~PyFileReader() {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
    pyistream_.release().dec_ref();
 | 
			
		||||
    readinto_func_.release().dec_ref();
 | 
			
		||||
    seek_func_.release().dec_ref();
 | 
			
		||||
    tell_func_.release().dec_ref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool is_open() const override {
 | 
			
		||||
    return !pyistream_.attr("closed").cast<bool>();
 | 
			
		||||
    bool out;
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
      out = !pyistream_.attr("closed").cast<bool>();
 | 
			
		||||
    }
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool good() const override {
 | 
			
		||||
    return !pyistream_.is_none();
 | 
			
		||||
    bool out;
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
      out = !pyistream_.is_none();
 | 
			
		||||
    }
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  size_t tell() const override {
 | 
			
		||||
    return tell_func_().cast<size_t>();
 | 
			
		||||
    size_t out;
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
      out = tell_func_().cast<size_t>();
 | 
			
		||||
    }
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
 | 
			
		||||
      override {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
    seek_func_(off, (int)way);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void read(char* data, size_t n) override {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
    py::object bytes_read =
 | 
			
		||||
        readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
 | 
			
		||||
 | 
			
		||||
    if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
 | 
			
		||||
      throw std::runtime_error("[load] Failed to read from python stream");
 | 
			
		||||
    }
 | 
			
		||||
@@ -163,6 +189,7 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
    for (auto& [key, arr] : array_dict) {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      arr.eval();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -172,7 +199,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
  } else if (is_istream_object(file)) {
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
    auto arr = load(std::make_shared<PyFileReader>(file), s);
 | 
			
		||||
    arr.eval();
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      arr.eval();
 | 
			
		||||
    }
 | 
			
		||||
    return {arr};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -192,26 +222,54 @@ class PyFileWriter : public io::Writer {
 | 
			
		||||
        seek_func_(file.attr("seek")),
 | 
			
		||||
        tell_func_(file.attr("tell")) {}
 | 
			
		||||
 | 
			
		||||
  ~PyFileWriter() {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
    pyostream_.release().dec_ref();
 | 
			
		||||
    write_func_.release().dec_ref();
 | 
			
		||||
    seek_func_.release().dec_ref();
 | 
			
		||||
    tell_func_.release().dec_ref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool is_open() const override {
 | 
			
		||||
    return !pyostream_.attr("closed").cast<bool>();
 | 
			
		||||
    bool out;
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
      out = !pyostream_.attr("closed").cast<bool>();
 | 
			
		||||
    }
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool good() const override {
 | 
			
		||||
    return !pyostream_.is_none();
 | 
			
		||||
    bool out;
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
      out = !pyostream_.is_none();
 | 
			
		||||
    }
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  size_t tell() const override {
 | 
			
		||||
    return tell_func_().cast<size_t>();
 | 
			
		||||
    size_t out;
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_acquire gil;
 | 
			
		||||
      out = tell_func_().cast<size_t>();
 | 
			
		||||
    }
 | 
			
		||||
    return out;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
 | 
			
		||||
      override {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
    seek_func_(off, (int)way);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void write(const char* data, size_t n) override {
 | 
			
		||||
    py::gil_scoped_acquire gil;
 | 
			
		||||
 | 
			
		||||
    py::object bytes_written =
 | 
			
		||||
        write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
 | 
			
		||||
 | 
			
		||||
    if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
 | 
			
		||||
      throw std::runtime_error("[load] Failed to write to python stream");
 | 
			
		||||
    }
 | 
			
		||||
@@ -233,7 +291,12 @@ void mlx_save_helper(py::object file, array a, bool retain_graph) {
 | 
			
		||||
    save(py::cast<std::string>(file), a, retain_graph);
 | 
			
		||||
    return;
 | 
			
		||||
  } else if (is_ostream_object(file)) {
 | 
			
		||||
    save(std::make_shared<PyFileWriter>(file), a, retain_graph);
 | 
			
		||||
    auto writer = std::make_shared<PyFileWriter>(file);
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      save(writer, a, retain_graph);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -285,7 +348,11 @@ void mlx_savez_helper(
 | 
			
		||||
  for (auto [k, a] : arrays_dict) {
 | 
			
		||||
    std::string fname = k + ".npy";
 | 
			
		||||
    auto py_ostream = zipfile_object.open(fname, 'w');
 | 
			
		||||
    save(std::make_shared<PyFileWriter>(py_ostream), a);
 | 
			
		||||
    auto writer = std::make_shared<PyFileWriter>(py_ostream);
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      save(writer, a);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user