No gil eval (#565)

This commit is contained in:
Angelos Katharopoulos 2024-01-26 22:03:52 -08:00 committed by GitHub
parent 8993382aaa
commit 37d98ba6ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 19 deletions

View File

@ -39,6 +39,10 @@ py::list to_list(array& a, size_t index, int dim) {
} }
auto to_scalar(array& a) { auto to_scalar(array& a) {
{
py::gil_scoped_release nogil;
a.eval();
}
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case bool_:
return py::cast(a.item<bool>()); return py::cast(a.item<bool>());
@ -73,7 +77,10 @@ py::object tolist(array& a) {
if (a.ndim() == 0) { if (a.ndim() == 0) {
return to_scalar(a); return to_scalar(a);
} }
a.eval(); {
py::gil_scoped_release nogil;
a.eval();
}
py::object pl; py::object pl;
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case bool_:
@ -644,6 +651,7 @@ void init_array(py::module_& m) {
.def_buffer([](array& a) { .def_buffer([](array& a) {
// Eval if not already evaled // Eval if not already evaled
if (!a.is_evaled()) { if (!a.is_evaled()) {
py::gil_scoped_release nogil;
a.eval(); a.eval();
} }
return pybind11::buffer_info( return pybind11::buffer_info(
@ -942,6 +950,7 @@ void init_array(py::module_& m) {
"__repr__", "__repr__",
[](array& a) { [](array& a) {
if (!a.is_evaled()) { if (!a.is_evaled()) {
py::gil_scoped_release nogil;
a.eval(); a.eval();
} }
std::ostringstream os; std::ostringstream os;

View File

@ -195,6 +195,8 @@ mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
std::unordered_map<std::string, array> mlx_load_npz_helper( std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file, py::object file,
StreamOrDevice s) { StreamOrDevice s) {
bool own_file = py::isinstance<py::str>(file);
py::module_ zipfile = py::module_::import("zipfile"); py::module_ zipfile = py::module_::import("zipfile");
if (!is_zip_file(zipfile, file)) { if (!is_zip_file(zipfile, file)) {
throw std::invalid_argument( throw std::invalid_argument(
@ -223,9 +225,11 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
} }
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) { if (!own_file) {
py::gil_scoped_release gil; py::gil_scoped_release gil;
arr.eval(); for (auto& [key, arr] : array_dict) {
arr.eval();
}
} }
return array_dict; return array_dict;
@ -260,7 +264,7 @@ LoadOutputTypes mlx_load_helper(
fname = file.attr("name").cast<std::string>(); fname = file.attr("name").cast<std::string>();
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[load] Input must be a file-like object, or string"); "[load] Input must be a file-like object opened in binary mode, or string");
} }
size_t ext = fname.find_last_of('.'); size_t ext = fname.find_last_of('.');
if (ext == std::string::npos) { if (ext == std::string::npos) {
@ -432,7 +436,7 @@ void mlx_savez_helper(
auto py_ostream = zipfile_object.open(fname, 'w'); auto py_ostream = zipfile_object.open(fname, 'w');
auto writer = std::make_shared<PyFileWriter>(py_ostream); auto writer = std::make_shared<PyFileWriter>(py_ostream);
{ {
py::gil_scoped_release gil; py::gil_scoped_release nogil;
save(writer, a); save(writer, a);
} }
} }
@ -443,20 +447,20 @@ void mlx_savez_helper(
void mlx_save_safetensor_helper(py::object file, py::dict d) { void mlx_save_safetensor_helper(py::object file, py::dict d) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
save_safetensors(py::cast<std::string>(file), arrays_map); {
return; py::gil_scoped_release nogil;
save_safetensors(py::cast<std::string>(file), arrays_map);
}
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release gil; py::gil_scoped_release nogil;
save_safetensors(writer, arrays_map); save_safetensors(writer, arrays_map);
} }
} else {
return; throw std::invalid_argument(
"[save_safetensors] Input must be a file-like object, or string");
} }
throw std::invalid_argument(
"[save_safetensors] Input must be a file-like object, or string");
} }
void mlx_save_gguf_helper( void mlx_save_gguf_helper(
@ -468,12 +472,17 @@ void mlx_save_gguf_helper(
if (m) { if (m) {
auto metadata_map = auto metadata_map =
m.value().cast<std::unordered_map<std::string, MetaData>>(); m.value().cast<std::unordered_map<std::string, MetaData>>();
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map); {
py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
}
} else { } else {
save_gguf(py::cast<std::string>(file), arrays_map); {
py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map);
}
} }
return; } else {
throw std::invalid_argument("[save_gguf] Input must be a string");
} }
throw std::invalid_argument("[save_safetensors] Input must be a string");
} }

View File

@ -509,7 +509,10 @@ void init_transforms(py::module_& m) {
"eval", "eval",
[](const py::args& args) { [](const py::args& args) {
std::vector<array> arrays = tree_flatten(args); std::vector<array> arrays = tree_flatten(args);
eval(arrays); {
py::gil_scoped_release nogil;
eval(arrays);
}
}, },
R"pbdoc( R"pbdoc(
eval(*args) -> None eval(*args) -> None