allow pathlib.Path to save/load functions (#2541)

This commit is contained in:
Awni Hannun
2025-08-25 14:58:49 -07:00
committed by GitHub
parent d2f540f4e0
commit db14e29a0b
2 changed files with 48 additions and 30 deletions

View File

@@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) {
&mlx_save_helper,
"file"_a,
"arr"_a,
nb::sig("def save(file: str, arr: array) -> None"),
nb::sig(
"def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
R"pbdoc(
Save the array to a binary file in ``.npy`` format.
Args:
file (str): File to which the array is saved
file (str, pathlib.Path, file): File to which the array is saved
arr (array): Array to be saved.
)pbdoc");
m.def(
@@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) {
"file"_a,
"args"_a,
"kwargs"_a,
nb::sig(
"def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
R"pbdoc(
Save several arrays to a binary file in uncompressed ``.npz``
format.
@@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) {
mx.savez("model.npz", **dict(flat_params))
Args:
file (file, str): Path to file to which the arrays are saved.
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
@@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) {
nb::arg(),
"args"_a,
"kwargs"_a,
nb::sig("def savez_compressed(file: str, *args, **kwargs)"),
nb::sig(
"def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
R"pbdoc(
Save several arrays to a binary file in compressed ``.npz`` format.
Args:
file (file, str): Path to file to which the arrays are saved.
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
@@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
R"pbdoc(
Load array(s) from a binary file.
@@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) {
``.gguf``.
Args:
file (file, str): File in which the array is saved.
file (file, str, pathlib.Path): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the
format is inferred from the file extension. Supported formats:
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
@@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) {
"arrays"_a,
"metadata"_a = nb::none(),
nb::sig(
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
"def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
R"pbdoc(
Save array(s) to a binary file in ``.safetensors`` format.
@@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) {
information on the format.
Args:
file (file, str): File in which the array is saved.
file (file, str, pathlib.Path): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to
be saved.
metadata (dict(str, str), optional): The dictionary of
@@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) {
"arrays"_a,
"metadata"_a = nb::none(),
nb::sig(
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
"def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
R"pbdoc(
Save array(s) to a binary file in ``.gguf`` format.
@@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) {
more information on the format.
Args:
file (file, str): File in which the array is saved.
file (file, str, pathlib.Path): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to
be saved.
metadata (dict(str, Union[array, str, list(str)])): The dictionary