mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-20 16:11:14 +08:00
No gil eval (#565)
This commit is contained in:
parent
8993382aaa
commit
37d98ba6ff
@ -39,6 +39,10 @@ py::list to_list(array& a, size_t index, int dim) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto to_scalar(array& a) {
|
auto to_scalar(array& a) {
|
||||||
|
{
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
|
a.eval();
|
||||||
|
}
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
return py::cast(a.item<bool>());
|
return py::cast(a.item<bool>());
|
||||||
@ -73,7 +77,10 @@ py::object tolist(array& a) {
|
|||||||
if (a.ndim() == 0) {
|
if (a.ndim() == 0) {
|
||||||
return to_scalar(a);
|
return to_scalar(a);
|
||||||
}
|
}
|
||||||
a.eval();
|
{
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
|
a.eval();
|
||||||
|
}
|
||||||
py::object pl;
|
py::object pl;
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
@ -644,6 +651,7 @@ void init_array(py::module_& m) {
|
|||||||
.def_buffer([](array& a) {
|
.def_buffer([](array& a) {
|
||||||
// Eval if not already evaled
|
// Eval if not already evaled
|
||||||
if (!a.is_evaled()) {
|
if (!a.is_evaled()) {
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
a.eval();
|
a.eval();
|
||||||
}
|
}
|
||||||
return pybind11::buffer_info(
|
return pybind11::buffer_info(
|
||||||
@ -942,6 +950,7 @@ void init_array(py::module_& m) {
|
|||||||
"__repr__",
|
"__repr__",
|
||||||
[](array& a) {
|
[](array& a) {
|
||||||
if (!a.is_evaled()) {
|
if (!a.is_evaled()) {
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
a.eval();
|
a.eval();
|
||||||
}
|
}
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
@ -195,6 +195,8 @@ mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
|||||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
|
bool own_file = py::isinstance<py::str>(file);
|
||||||
|
|
||||||
py::module_ zipfile = py::module_::import("zipfile");
|
py::module_ zipfile = py::module_::import("zipfile");
|
||||||
if (!is_zip_file(zipfile, file)) {
|
if (!is_zip_file(zipfile, file)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -223,9 +225,11 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
for (auto& [key, arr] : array_dict) {
|
if (!own_file) {
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
arr.eval();
|
for (auto& [key, arr] : array_dict) {
|
||||||
|
arr.eval();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return array_dict;
|
return array_dict;
|
||||||
@ -260,7 +264,7 @@ LoadOutputTypes mlx_load_helper(
|
|||||||
fname = file.attr("name").cast<std::string>();
|
fname = file.attr("name").cast<std::string>();
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[load] Input must be a file-like object, or string");
|
"[load] Input must be a file-like object opened in binary mode, or string");
|
||||||
}
|
}
|
||||||
size_t ext = fname.find_last_of('.');
|
size_t ext = fname.find_last_of('.');
|
||||||
if (ext == std::string::npos) {
|
if (ext == std::string::npos) {
|
||||||
@ -432,7 +436,7 @@ void mlx_savez_helper(
|
|||||||
auto py_ostream = zipfile_object.open(fname, 'w');
|
auto py_ostream = zipfile_object.open(fname, 'w');
|
||||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release nogil;
|
||||||
save(writer, a);
|
save(writer, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -443,20 +447,20 @@ void mlx_savez_helper(
|
|||||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
save_safetensors(py::cast<std::string>(file), arrays_map);
|
{
|
||||||
return;
|
py::gil_scoped_release nogil;
|
||||||
|
save_safetensors(py::cast<std::string>(file), arrays_map);
|
||||||
|
}
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release nogil;
|
||||||
save_safetensors(writer, arrays_map);
|
save_safetensors(writer, arrays_map);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
return;
|
throw std::invalid_argument(
|
||||||
|
"[save_safetensors] Input must be a file-like object, or string");
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[save_safetensors] Input must be a file-like object, or string");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_save_gguf_helper(
|
void mlx_save_gguf_helper(
|
||||||
@ -468,12 +472,17 @@ void mlx_save_gguf_helper(
|
|||||||
if (m) {
|
if (m) {
|
||||||
auto metadata_map =
|
auto metadata_map =
|
||||||
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
||||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
{
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
|
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
{
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
|
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return;
|
} else {
|
||||||
|
throw std::invalid_argument("[save_gguf] Input must be a string");
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument("[save_safetensors] Input must be a string");
|
|
||||||
}
|
}
|
||||||
|
@ -509,7 +509,10 @@ void init_transforms(py::module_& m) {
|
|||||||
"eval",
|
"eval",
|
||||||
[](const py::args& args) {
|
[](const py::args& args) {
|
||||||
std::vector<array> arrays = tree_flatten(args);
|
std::vector<array> arrays = tree_flatten(args);
|
||||||
eval(arrays);
|
{
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
|
eval(arrays);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
eval(*args) -> None
|
eval(*args) -> None
|
||||||
|
Loading…
Reference in New Issue
Block a user