diff --git a/python/src/load.cpp b/python/src/load.cpp index e992f2077..f98307f19 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -23,6 +23,14 @@ using namespace nb::literals; // Helpers /////////////////////////////////////////////////////////////////////////////// +bool is_str_or_path(nb::object obj) { + if (nb::isinstance(obj)) { + return true; + } + nb::object path_type = nb::module_::import_("pathlib").attr("Path"); + return nb::isinstance(obj, path_type); +} + bool is_istream_object(const nb::object& file) { return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") && nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); @@ -172,8 +180,9 @@ std::pair< std::unordered_map, std::unordered_map> mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { - if (nb::isinstance(file)) { // Assume .safetensors file path string - return mx::load_safetensors(nb::cast(file), s); + if (is_str_or_path(file)) { // Assume .safetensors file path string + auto file_str = nb::cast(nb::str(file)); + return mx::load_safetensors(file_str, s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately auto res = mx::load_safetensors(std::make_shared(file), s); @@ -191,8 +200,9 @@ mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { } mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { - if (nb::isinstance(file)) { // Assume .gguf file path string - return mx::load_gguf(nb::cast(file), s); + if (is_str_or_path(file)) { // Assume .gguf file path string + auto file_str = nb::cast(nb::str(file)); + return mx::load_gguf(file_str, s); } throw std::invalid_argument("[load_gguf] Input must be a string"); @@ -201,7 +211,7 @@ mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { std::unordered_map mlx_load_npz_helper( nb::object file, mx::StreamOrDevice s) { - bool own_file = nb::isinstance(file); + bool own_file = is_str_or_path(file); nb::module_ zipfile = nb::module_::import_("zipfile"); if (!is_zip_file(zipfile, file)) { @@ -242,8 +252,9 @@ std::unordered_map mlx_load_npz_helper( } mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) { - if (nb::isinstance(file)) { // Assume .npy file path string - return mx::load(nb::cast(file), s); + if (is_str_or_path(file)) { // Assume .npy file path string + auto file_str = nb::cast(nb::str(file)); + return mx::load(file_str, s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately auto arr = mx::load(std::make_shared(file), s); @@ -264,8 +275,8 @@ LoadOutputTypes mlx_load_helper( mx::StreamOrDevice s) { if (!format.has_value()) { std::string fname; - if (nb::isinstance(file)) { - fname = nb::cast(file); + if (is_str_or_path(file)) { + fname = nb::cast(nb::str(file)); } else if (is_istream_object(file)) { fname = nb::cast(file.attr("name")); } else { @@ -384,8 +395,9 @@ class PyFileWriter : public mx::io::Writer { }; void mlx_save_helper(nb::object file, mx::array a) { - if (nb::isinstance(file)) { - mx::save(nb::cast(file), a); + if (is_str_or_path(file)) { + auto file_str = nb::cast(nb::str(file)); + mx::save(file_str, a); return; } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); @@ -409,8 +421,8 @@ void mlx_savez_helper( // Add .npz to the end of the filename if not already there nb::object file = file_; - if (nb::isinstance(file_)) { - std::string fname = nb::cast(file_); + if (is_str_or_path(file)) { + std::string fname = nb::cast(nb::str(file_)); // Add .npz to file name if it is not there if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") @@ -473,11 +485,11 @@ void mlx_save_safetensor_helper( metadata_map = std::unordered_map(); } auto arrays_map = nb::cast>(d); - if (nb::isinstance(file)) { + if (is_str_or_path(file)) { { + auto file_str = nb::cast(nb::str(file)); nb::gil_scoped_release nogil; - mx::save_safetensors( - nb::cast(file), arrays_map, metadata_map); + mx::save_safetensors(file_str, arrays_map, metadata_map); } } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); @@ -496,19 +508,21 @@ void mlx_save_gguf_helper( nb::dict a, std::optional m) { auto arrays_map = nb::cast>(a); - if (nb::isinstance(file)) { + if (is_str_or_path(file)) { if (m) { auto metadata_map = nb::cast>( m.value()); { + auto file_str = nb::cast(nb::str(file)); nb::gil_scoped_release nogil; - mx::save_gguf(nb::cast(file), arrays_map, metadata_map); + mx::save_gguf(file_str, arrays_map, metadata_map); } } else { { + auto file_str = nb::cast(nb::str(file)); nb::gil_scoped_release nogil; - mx::save_gguf(nb::cast(file), arrays_map); + mx::save_gguf(file_str, arrays_map); } } } else { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index af64d9dfc..f2a27e282 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) { &mlx_save_helper, "file"_a, "arr"_a, - nb::sig("def save(file: str, arr: array) -> None"), + nb::sig( + "def save(file: Union[file, str, pathlib.Path], arr: array) -> None"), R"pbdoc( Save the array to a binary file in ``.npy`` format. Args: - file (str): File to which the array is saved + file (str, pathlib.Path, file): File to which the array is saved arr (array): Array to be saved. )pbdoc"); m.def( @@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) { "file"_a, "args"_a, "kwargs"_a, + nb::sig( + "def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"), R"pbdoc( Save several arrays to a binary file in uncompressed ``.npz`` format. @@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) { mx.savez("model.npz", **dict(flat_params)) Args: - file (file, str): Path to file to which the arrays are saved. + file (file, str, pathlib.Path): Path to file to which the arrays are saved. *args (arrays): Arrays to be saved. **kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. @@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) { nb::arg(), "args"_a, "kwargs"_a, - nb::sig("def savez_compressed(file: str, *args, **kwargs)"), + nb::sig( + "def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"), R"pbdoc( Save several arrays to a binary file in compressed ``.npz`` format. Args: - file (file, str): Path to file to which the arrays are saved. + file (file, str, pathlib.Path): Path to file to which the arrays are saved. *args (arrays): Arrays to be saved. **kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. @@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"), + "def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"), R"pbdoc( Load array(s) from a binary file. @@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) { ``.gguf``. Args: - file (file, str): File in which the array is saved. + file (file, str, pathlib.Path): File in which the array is saved. format (str, optional): Format of the file. If ``None``, the format is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. Default: ``None``. @@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) { "arrays"_a, "metadata"_a = nb::none(), nb::sig( - "def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"), + "def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"), R"pbdoc( Save array(s) to a binary file in ``.safetensors`` format. @@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) { information on the format. Args: - file (file, str): File in which the array is saved. + file (file, str, pathlib.Path): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved. metadata (dict(str, str), optional): The dictionary of @@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) { "arrays"_a, "metadata"_a = nb::none(), nb::sig( - "def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"), + "def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"), R"pbdoc( Save array(s) to a binary file in ``.gguf`` format. @@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) { more information on the format. Args: - file (file, str): File in which the array is saved. + file (file, str, pathlib.Path): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved. metadata (dict(str, Union[array, str, list(str)])): The dictionary