mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
add python bindings for loading
This commit is contained in:
parent
dcfa2700f6
commit
9a39254959
@ -161,6 +161,27 @@ class PyFileReader : public io::Reader {
|
|||||||
py::object tell_func_;
|
py::object tell_func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||||
|
py::object file,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||||
|
return {load_safetensor(py::cast<std::string>(file), s)};
|
||||||
|
} else if (is_istream_object(file)) {
|
||||||
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
|
auto arr = load_safetensor(std::make_shared<PyFileReader>(file), s);
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil;
|
||||||
|
for (auto& [key, arr] : arr) {
|
||||||
|
arr.eval();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {arr};
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[load] Input must be a file-like object, string, or pathlib.Path");
|
||||||
|
}
|
||||||
|
|
||||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
||||||
py::module_ zipfile = py::module_::import("zipfile");
|
py::module_ zipfile = py::module_::import("zipfile");
|
||||||
|
|
||||||
|
@ -12,6 +12,9 @@ using namespace mlx::core;
|
|||||||
|
|
||||||
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||||
|
py::object file,
|
||||||
|
StreamOrDevice s);
|
||||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
||||||
void mlx_save_helper(
|
void mlx_save_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
|
@ -2945,6 +2945,24 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
|
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"load_safetensor",
|
||||||
|
&mlx_load_safetensor_helper,
|
||||||
|
"file"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
load_safetensor(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Dict[str, array]
|
||||||
|
|
||||||
|
Load array(s) from a binary file in ``.safetensors`` format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (file, str): File in which the array is saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result dict: The loaded dict mapping name to array from the ``.safetensors`` file
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"where",
|
"where",
|
||||||
[](const ScalarOrArray& condition,
|
[](const ScalarOrArray& condition,
|
||||||
|
Loading…
Reference in New Issue
Block a user