mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
added save bindings and fixed header
This commit is contained in:
parent
9a39254959
commit
c88d3174aa
@ -1066,7 +1066,7 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Writer> in_stream,
|
||||||
std::unordered_map<std::string, array>);
|
std::unordered_map<std::string, array>);
|
||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
|
@ -382,3 +382,22 @@ void mlx_savez_helper(
|
|||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||||
|
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||||
|
if (py::isinstance<py::str>(file)) {
|
||||||
|
save_safetensor(py::cast<std::string>(file), arrays_map);
|
||||||
|
return;
|
||||||
|
} else if (is_ostream_object(file)) {
|
||||||
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil;
|
||||||
|
save_safetensor(writer, arrays_map);
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[save_safetensor] Input must be a file-like object, string, or pathlib.Path");
|
||||||
|
}
|
@ -15,6 +15,8 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
|||||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
StreamOrDevice s);
|
StreamOrDevice s);
|
||||||
|
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||||
|
|
||||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
||||||
void mlx_save_helper(
|
void mlx_save_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
|
@ -2963,6 +2963,22 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result dict: The loaded dict mapping name to array from the ``.safetensors`` file
|
result dict: The loaded dict mapping name to array from the ``.safetensors`` file
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"save_safetensor",
|
||||||
|
&mlx_save_safetensor_helper,
|
||||||
|
"file"_a,
|
||||||
|
"d"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
R"pbdoc(
|
||||||
|
save_safetensor(file: str, d: Dict[str, array], /, *, stream: Union[None, Stream, Device] = None)
|
||||||
|
|
||||||
|
Save array(s) to a binary file in ``.safetensors`` format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (file, str): File in which the array is saved
|
||||||
|
d (Dict[str, array]): The dict mapping name to array to be saved
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"where",
|
"where",
|
||||||
[](const ScalarOrArray& condition,
|
[](const ScalarOrArray& condition,
|
||||||
|
Loading…
Reference in New Issue
Block a user