NPY loading segfault bug (#34)

* Fixed GIL semantics when loading from and saving to Python file streams
This commit is contained in:
Jagrit Digani 2023-12-06 12:03:47 -08:00 committed by GitHub
parent 170e4b2d43
commit 2440fe0124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 12 deletions

View File

@ -33,7 +33,7 @@ def silu(x):
def gelu(x): def gelu(x):
"""Applies the Gaussian Error Linear Units function. r"""Applies the Gaussian Error Linear Units function.
.. math:: .. math::
\\textrm{GELU}(x) = x * \Phi(x) \\textrm{GELU}(x) = x * \Phi(x)

View File

@ -11,8 +11,6 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <iostream>
#include "mlx/load.h" #include "mlx/load.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@ -99,26 +97,54 @@ class PyFileReader : public io::Reader {
seek_func_(file.attr("seek")), seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {} tell_func_(file.attr("tell")) {}
~PyFileReader() {
py::gil_scoped_acquire gil;
pyistream_.release().dec_ref();
readinto_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}
bool is_open() const override { bool is_open() const override {
return !pyistream_.attr("closed").cast<bool>(); bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.attr("closed").cast<bool>();
}
return out;
} }
bool good() const override { bool good() const override {
return !pyistream_.is_none(); bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.is_none();
}
return out;
} }
size_t tell() const override { size_t tell() const override {
return tell_func_().cast<size_t>(); size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
}
return out;
} }
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override { override {
py::gil_scoped_acquire gil;
seek_func_(off, (int)way); seek_func_(off, (int)way);
} }
void read(char* data, size_t n) override { void read(char* data, size_t n) override {
py::gil_scoped_acquire gil;
py::object bytes_read = py::object bytes_read =
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) { if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream"); throw std::runtime_error("[load] Failed to read from python stream");
} }
@ -163,6 +189,7 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) { for (auto& [key, arr] : array_dict) {
py::gil_scoped_release gil;
arr.eval(); arr.eval();
} }
@ -172,7 +199,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s); auto arr = load(std::make_shared<PyFileReader>(file), s);
arr.eval(); {
py::gil_scoped_release gil;
arr.eval();
}
return {arr}; return {arr};
} }
@ -192,26 +222,54 @@ class PyFileWriter : public io::Writer {
seek_func_(file.attr("seek")), seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {} tell_func_(file.attr("tell")) {}
~PyFileWriter() {
py::gil_scoped_acquire gil;
pyostream_.release().dec_ref();
write_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}
bool is_open() const override { bool is_open() const override {
return !pyostream_.attr("closed").cast<bool>(); bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.attr("closed").cast<bool>();
}
return out;
} }
bool good() const override { bool good() const override {
return !pyostream_.is_none(); bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.is_none();
}
return out;
} }
size_t tell() const override { size_t tell() const override {
return tell_func_().cast<size_t>(); size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
}
return out;
} }
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override { override {
py::gil_scoped_acquire gil;
seek_func_(off, (int)way); seek_func_(off, (int)way);
} }
void write(const char* data, size_t n) override { void write(const char* data, size_t n) override {
py::gil_scoped_acquire gil;
py::object bytes_written = py::object bytes_written =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) { if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream"); throw std::runtime_error("[load] Failed to write to python stream");
} }
@ -233,7 +291,12 @@ void mlx_save_helper(py::object file, array a, bool retain_graph) {
save(py::cast<std::string>(file), a, retain_graph); save(py::cast<std::string>(file), a, retain_graph);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
save(std::make_shared<PyFileWriter>(file), a, retain_graph); auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save(writer, a, retain_graph);
}
return; return;
} }
@ -285,7 +348,11 @@ void mlx_savez_helper(
for (auto [k, a] : arrays_dict) { for (auto [k, a] : arrays_dict) {
std::string fname = k + ".npy"; std::string fname = k + ".npy";
auto py_ostream = zipfile_object.open(fname, 'w'); auto py_ostream = zipfile_object.open(fname, 'w');
save(std::make_shared<PyFileWriter>(py_ostream), a); auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release gil;
save(writer, a);
}
} }
return; return;