mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
allow pathlib.Path to save/load functions (#2541)
This commit is contained in:
@@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) {
|
||||
&mlx_save_helper,
|
||||
"file"_a,
|
||||
"arr"_a,
|
||||
nb::sig("def save(file: str, arr: array) -> None"),
|
||||
nb::sig(
|
||||
"def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
|
||||
R"pbdoc(
|
||||
Save the array to a binary file in ``.npy`` format.
|
||||
|
||||
Args:
|
||||
file (str): File to which the array is saved
|
||||
file (str, pathlib.Path, file): File to which the array is saved
|
||||
arr (array): Array to be saved.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) {
|
||||
"file"_a,
|
||||
"args"_a,
|
||||
"kwargs"_a,
|
||||
nb::sig(
|
||||
"def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
|
||||
R"pbdoc(
|
||||
Save several arrays to a binary file in uncompressed ``.npz``
|
||||
format.
|
||||
@@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) {
|
||||
mx.savez("model.npz", **dict(flat_params))
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
@@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) {
|
||||
nb::arg(),
|
||||
"args"_a,
|
||||
"kwargs"_a,
|
||||
nb::sig("def savez_compressed(file: str, *args, **kwargs)"),
|
||||
nb::sig(
|
||||
"def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
|
||||
R"pbdoc(
|
||||
Save several arrays to a binary file in compressed ``.npz`` format.
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
@@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
||||
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
||||
R"pbdoc(
|
||||
Load array(s) from a binary file.
|
||||
|
||||
@@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) {
|
||||
``.gguf``.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
file (file, str, pathlib.Path): File in which the array is saved.
|
||||
format (str, optional): Format of the file. If ``None``, the
|
||||
format is inferred from the file extension. Supported formats:
|
||||
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
|
||||
@@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) {
|
||||
"arrays"_a,
|
||||
"metadata"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
|
||||
"def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
|
||||
R"pbdoc(
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
@@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) {
|
||||
information on the format.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
file (file, str, pathlib.Path): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||
be saved.
|
||||
metadata (dict(str, str), optional): The dictionary of
|
||||
@@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) {
|
||||
"arrays"_a,
|
||||
"metadata"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
|
||||
"def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
|
||||
R"pbdoc(
|
||||
Save array(s) to a binary file in ``.gguf`` format.
|
||||
|
||||
@@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) {
|
||||
more information on the format.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
file (file, str, pathlib.Path): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||
be saved.
|
||||
metadata (dict(str, Union[array, str, list(str)])): The dictionary
|
||||
|
Reference in New Issue
Block a user