mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Fix load compilation (#298)
This commit is contained in:
parent
1f6ab6a556
commit
79c95b6919
@ -164,7 +164,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
|||||||
py::object file,
|
py::object file,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||||
return {load_safetensors(py::cast<std::string>(file), s)};
|
return load_safetensors(py::cast<std::string>(file), s);
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||||
@ -174,7 +174,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
|||||||
arr.eval();
|
arr.eval();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {arr};
|
return arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -217,12 +217,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
|||||||
arr.eval();
|
arr.eval();
|
||||||
}
|
}
|
||||||
|
|
||||||
return {array_dict};
|
return array_dict;
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||||
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
||||||
return {load(py::cast<std::string>(file), s)};
|
return load(py::cast<std::string>(file), s);
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||||
@ -230,7 +230,7 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
|||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
arr.eval();
|
arr.eval();
|
||||||
}
|
}
|
||||||
return {arr};
|
return arr;
|
||||||
}
|
}
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[load_npy] Input must be a file-like object, or string");
|
"[load_npy] Input must be a file-like object, or string");
|
||||||
|
Loading…
Reference in New Issue
Block a user