added save bindings and fixed header

This commit is contained in:
dc-dc-dc 2023-12-18 20:25:51 -05:00
parent 9a39254959
commit c88d3174aa
4 changed files with 38 additions and 1 deletions

View File

@ -1066,7 +1066,7 @@ std::unordered_map<std::string, array> load_safetensor(
StreamOrDevice s = {}); StreamOrDevice s = {});
void save_safetensor( void save_safetensor(
std::shared_ptr<io::Reader> in_stream, std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>); std::unordered_map<std::string, array>);
void save_safetensor( void save_safetensor(
const std::string& file, const std::string& file,

View File

@ -382,3 +382,22 @@ void mlx_savez_helper(
return; return;
} }
void mlx_save_safetensor_helper(py::object file, py::dict d) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
save_safetensor(py::cast<std::string>(file), arrays_map);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save_safetensor(writer, arrays_map);
}
return;
}
throw std::invalid_argument(
"[save_safetensor] Input must be a file-like object, string, or pathlib.Path");
}

View File

@ -15,6 +15,8 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
std::unordered_map<std::string, array> mlx_load_safetensor_helper( std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file, py::object file,
StreamOrDevice s); StreamOrDevice s);
void mlx_save_safetensor_helper(py::object file, py::dict d);
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
void mlx_save_helper( void mlx_save_helper(
py::object file, py::object file,

View File

@ -2963,6 +2963,22 @@ void init_ops(py::module_& m) {
Returns: Returns:
result dict: The loaded dict mapping name to array from the ``.safetensors`` file result dict: The loaded dict mapping name to array from the ``.safetensors`` file
)pbdoc"); )pbdoc");
m.def(
"save_safetensor",
&mlx_save_safetensor_helper,
"file"_a,
"d"_a,
py::pos_only(),
py::kw_only(),
R"pbdoc(
save_safetensor(file: str, d: Dict[str, array], /, *, stream: Union[None, Stream, Device] = None)
Save array(s) to a binary file in ``.safetensors`` format.
Args:
file (file, str): File in which the array is saved
d (Dict[str, array]): The dict mapping name to array to be saved
)pbdoc");
m.def( m.def(
"where", "where",
[](const ScalarOrArray& condition, [](const ScalarOrArray& condition,