diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index dbe5f4adf..cb1460ef3 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -33,7 +33,7 @@ def silu(x): def gelu(x): - """Applies the Gaussian Error Linear Units function. + r"""Applies the Gaussian Error Linear Units function. .. math:: \\textrm{GELU}(x) = x * \Phi(x) diff --git a/python/src/load.cpp b/python/src/load.cpp index b0327e125..d3c603008 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -11,8 +11,6 @@ #include #include -#include - #include "mlx/load.h" #include "mlx/ops.h" #include "mlx/utils.h" @@ -99,26 +97,54 @@ class PyFileReader : public io::Reader { seek_func_(file.attr("seek")), tell_func_(file.attr("tell")) {} + ~PyFileReader() { + py::gil_scoped_acquire gil; + + pyistream_.release().dec_ref(); + readinto_func_.release().dec_ref(); + seek_func_.release().dec_ref(); + tell_func_.release().dec_ref(); + } + bool is_open() const override { - return !pyistream_.attr("closed").cast(); + bool out; + { + py::gil_scoped_acquire gil; + out = !pyistream_.attr("closed").cast(); + } + return out; } bool good() const override { - return !pyistream_.is_none(); + bool out; + { + py::gil_scoped_acquire gil; + out = !pyistream_.is_none(); + } + return out; } size_t tell() const override { - return tell_func_().cast(); + size_t out; + { + py::gil_scoped_acquire gil; + out = tell_func_().cast(); + } + return out; } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { + py::gil_scoped_acquire gil; seek_func_(off, (int)way); } void read(char* data, size_t n) override { + py::gil_scoped_acquire gil; + py::object bytes_read = readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); + if (bytes_read.is_none() || py::cast(bytes_read) < n) { throw std::runtime_error("[load] Failed to read from python stream"); } @@ -163,6 +189,7 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { // If we don't own the stream and it was passed to us, eval immediately for (auto& [key, arr] : array_dict) { + py::gil_scoped_release gil; arr.eval(); } @@ -172,7 +199,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately auto arr = load(std::make_shared(file), s); - arr.eval(); + { + py::gil_scoped_release gil; + arr.eval(); + } return {arr}; } @@ -192,26 +222,54 @@ class PyFileWriter : public io::Writer { seek_func_(file.attr("seek")), tell_func_(file.attr("tell")) {} + ~PyFileWriter() { + py::gil_scoped_acquire gil; + + pyostream_.release().dec_ref(); + write_func_.release().dec_ref(); + seek_func_.release().dec_ref(); + tell_func_.release().dec_ref(); + } + bool is_open() const override { - return !pyostream_.attr("closed").cast(); + bool out; + { + py::gil_scoped_acquire gil; + out = !pyostream_.attr("closed").cast(); + } + return out; } bool good() const override { - return !pyostream_.is_none(); + bool out; + { + py::gil_scoped_acquire gil; + out = !pyostream_.is_none(); + } + return out; } size_t tell() const override { - return tell_func_().cast(); + size_t out; + { + py::gil_scoped_acquire gil; + out = tell_func_().cast(); + } + return out; } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { + py::gil_scoped_acquire gil; seek_func_(off, (int)way); } void write(const char* data, size_t n) override { + py::gil_scoped_acquire gil; + py::object bytes_written = write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); + if (bytes_written.is_none() || py::cast(bytes_written) < n) { throw std::runtime_error("[load] Failed to write to python stream"); } @@ -233,7 +291,12 @@ void mlx_save_helper(py::object file, array a, bool retain_graph) { save(py::cast(file), a, retain_graph); return; } else if (is_ostream_object(file)) { - save(std::make_shared(file), a, retain_graph); + auto writer = std::make_shared(file); + { + py::gil_scoped_release gil; + save(writer, a, retain_graph); + } + return; } @@ -285,7 +348,11 @@ void mlx_savez_helper( for (auto [k, a] : arrays_dict) { std::string fname = k + ".npy"; auto py_ostream = zipfile_object.open(fname, 'w'); - save(std::make_shared(py_ostream), a); + auto writer = std::make_shared(py_ostream); + { + py::gil_scoped_release gil; + save(writer, a); + } } return;