Safetensor support (#215)

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-27 05:06:55 -05:00
committed by GitHub
parent 6b0d30bb85
commit 1f6ab6a556
17 changed files with 476 additions and 52 deletions

View File

@@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) {
Args:
file (str): File to which the array is saved
arr (array): Array to be saved.
retain_graph (bool, optional): Optional argument to retain graph
during array evaluation before saving. If not provided the graph
is retained if we are during a function transformation. Default:
None
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``
)pbdoc");
m.def(
"savez",
@@ -2932,18 +2930,45 @@ void init_ops(py::module_& m) {
&mlx_load_helper,
"file"_a,
py::pos_only(),
"format"_a = none,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
Args:
file (file, str): File in which the array is saved
file (file, str): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the format
is inferred from the file extension. Supported formats: ``npy``,
``npz``, and ``safetensors``. Default: ``None``.
Returns:
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
result (array, dict):
A single array if loading from a ``.npy`` file or a dict mapping
names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
)pbdoc");
m.def(
"save_safetensors",
&mlx_save_safetensor_helper,
"file"_a,
"arrays"_a,
py::pos_only(),
"retain_graph"_a = std::nullopt,
py::kw_only(),
R"pbdoc(
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
Save array(s) to a binary file in ``.safetensors`` format.
For more information on the format see https://huggingface.co/docs/safetensors/index.
Args:
file (file, str): File in which the array is saved>
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
retain_graph (bool, optional): Whether or not to retain the graph
during array evaluation. If left unspecified the graph is retained
only if saving is done in a function transformation. Default: ``None``.
)pbdoc");
m.def(
"where",