mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
NPY loading segfault bug (#34)
* Fixed Gil semantics in loading and saving from python file streams
This commit is contained in:
parent
170e4b2d43
commit
2440fe0124
@ -33,7 +33,7 @@ def silu(x):
|
|||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Applies the Gaussian Error Linear Units function.
|
r"""Applies the Gaussian Error Linear Units function.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\\textrm{GELU}(x) = x * \Phi(x)
|
\\textrm{GELU}(x) = x * \Phi(x)
|
||||||
|
@ -11,8 +11,6 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "mlx/load.h"
|
#include "mlx/load.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
@ -99,26 +97,54 @@ class PyFileReader : public io::Reader {
|
|||||||
seek_func_(file.attr("seek")),
|
seek_func_(file.attr("seek")),
|
||||||
tell_func_(file.attr("tell")) {}
|
tell_func_(file.attr("tell")) {}
|
||||||
|
|
||||||
|
~PyFileReader() {
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
|
||||||
|
pyistream_.release().dec_ref();
|
||||||
|
readinto_func_.release().dec_ref();
|
||||||
|
seek_func_.release().dec_ref();
|
||||||
|
tell_func_.release().dec_ref();
|
||||||
|
}
|
||||||
|
|
||||||
bool is_open() const override {
|
bool is_open() const override {
|
||||||
return !pyistream_.attr("closed").cast<bool>();
|
bool out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
out = !pyistream_.attr("closed").cast<bool>();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool good() const override {
|
bool good() const override {
|
||||||
return !pyistream_.is_none();
|
bool out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
out = !pyistream_.is_none();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t tell() const override {
|
size_t tell() const override {
|
||||||
return tell_func_().cast<size_t>();
|
size_t out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
out = tell_func_().cast<size_t>();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
override {
|
override {
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
seek_func_(off, (int)way);
|
seek_func_(off, (int)way);
|
||||||
}
|
}
|
||||||
|
|
||||||
void read(char* data, size_t n) override {
|
void read(char* data, size_t n) override {
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
|
||||||
py::object bytes_read =
|
py::object bytes_read =
|
||||||
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||||
|
|
||||||
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
|
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
|
||||||
throw std::runtime_error("[load] Failed to read from python stream");
|
throw std::runtime_error("[load] Failed to read from python stream");
|
||||||
}
|
}
|
||||||
@ -163,6 +189,7 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
|||||||
|
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
for (auto& [key, arr] : array_dict) {
|
for (auto& [key, arr] : array_dict) {
|
||||||
|
py::gil_scoped_release gil;
|
||||||
arr.eval();
|
arr.eval();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,7 +199,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
|||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil;
|
||||||
arr.eval();
|
arr.eval();
|
||||||
|
}
|
||||||
return {arr};
|
return {arr};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,26 +222,54 @@ class PyFileWriter : public io::Writer {
|
|||||||
seek_func_(file.attr("seek")),
|
seek_func_(file.attr("seek")),
|
||||||
tell_func_(file.attr("tell")) {}
|
tell_func_(file.attr("tell")) {}
|
||||||
|
|
||||||
|
~PyFileWriter() {
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
|
||||||
|
pyostream_.release().dec_ref();
|
||||||
|
write_func_.release().dec_ref();
|
||||||
|
seek_func_.release().dec_ref();
|
||||||
|
tell_func_.release().dec_ref();
|
||||||
|
}
|
||||||
|
|
||||||
bool is_open() const override {
|
bool is_open() const override {
|
||||||
return !pyostream_.attr("closed").cast<bool>();
|
bool out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
out = !pyostream_.attr("closed").cast<bool>();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool good() const override {
|
bool good() const override {
|
||||||
return !pyostream_.is_none();
|
bool out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
out = !pyostream_.is_none();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t tell() const override {
|
size_t tell() const override {
|
||||||
return tell_func_().cast<size_t>();
|
size_t out;
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
out = tell_func_().cast<size_t>();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||||
override {
|
override {
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
seek_func_(off, (int)way);
|
seek_func_(off, (int)way);
|
||||||
}
|
}
|
||||||
|
|
||||||
void write(const char* data, size_t n) override {
|
void write(const char* data, size_t n) override {
|
||||||
|
py::gil_scoped_acquire gil;
|
||||||
|
|
||||||
py::object bytes_written =
|
py::object bytes_written =
|
||||||
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||||
|
|
||||||
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
|
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
|
||||||
throw std::runtime_error("[load] Failed to write to python stream");
|
throw std::runtime_error("[load] Failed to write to python stream");
|
||||||
}
|
}
|
||||||
@ -233,7 +291,12 @@ void mlx_save_helper(py::object file, array a, bool retain_graph) {
|
|||||||
save(py::cast<std::string>(file), a, retain_graph);
|
save(py::cast<std::string>(file), a, retain_graph);
|
||||||
return;
|
return;
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
save(std::make_shared<PyFileWriter>(file), a, retain_graph);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil;
|
||||||
|
save(writer, a, retain_graph);
|
||||||
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,7 +348,11 @@ void mlx_savez_helper(
|
|||||||
for (auto [k, a] : arrays_dict) {
|
for (auto [k, a] : arrays_dict) {
|
||||||
std::string fname = k + ".npy";
|
std::string fname = k + ".npy";
|
||||||
auto py_ostream = zipfile_object.open(fname, 'w');
|
auto py_ostream = zipfile_object.open(fname, 'w');
|
||||||
save(std::make_shared<PyFileWriter>(py_ostream), a);
|
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil;
|
||||||
|
save(writer, a);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
Loading…
Reference in New Issue
Block a user