Fix load compilation (#298)

This commit is contained in:
Angelos Katharopoulos 2023-12-27 06:20:45 -08:00 committed by GitHub
parent 1f6ab6a556
commit 79c95b6919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -164,7 +164,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file,
StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
return {load_safetensors(py::cast<std::string>(file), s)};
return load_safetensors(py::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
@ -174,7 +174,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
arr.eval();
}
}
return {arr};
return arr;
}
throw std::invalid_argument(
@ -217,12 +217,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
arr.eval();
}
return {array_dict};
return array_dict;
}
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return {load(py::cast<std::string>(file), s)};
return load(py::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
@ -230,7 +230,7 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
py::gil_scoped_release gil;
arr.eval();
}
return {arr};
return arr;
}
throw std::invalid_argument(
"[load_npy] Input must be a file-like object, or string");