mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	allow pathlib.Path to save/load functions (#2541)
This commit is contained in:
		| @@ -23,6 +23,14 @@ using namespace nb::literals; | ||||
| // Helpers | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| bool is_str_or_path(nb::object obj) { | ||||
|   if (nb::isinstance<nb::str>(obj)) { | ||||
|     return true; | ||||
|   } | ||||
|   nb::object path_type = nb::module_::import_("pathlib").attr("Path"); | ||||
|   return nb::isinstance(obj, path_type); | ||||
| } | ||||
|  | ||||
| bool is_istream_object(const nb::object& file) { | ||||
|   return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") && | ||||
|       nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); | ||||
| @@ -172,8 +180,9 @@ std::pair< | ||||
|     std::unordered_map<std::string, mx::array>, | ||||
|     std::unordered_map<std::string, std::string>> | ||||
| mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { | ||||
|   if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string | ||||
|     return mx::load_safetensors(nb::cast<std::string>(file), s); | ||||
|   if (is_str_or_path(file)) { // Assume .safetensors file path string | ||||
|     auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|     return mx::load_safetensors(file_str, s); | ||||
|   } else if (is_istream_object(file)) { | ||||
|     // If we don't own the stream and it was passed to us, eval immediately | ||||
|     auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s); | ||||
| @@ -191,8 +200,9 @@ mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { | ||||
| } | ||||
|  | ||||
| mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { | ||||
|   if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string | ||||
|     return mx::load_gguf(nb::cast<std::string>(file), s); | ||||
|   if (is_str_or_path(file)) { // Assume .gguf file path string | ||||
|     auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|     return mx::load_gguf(file_str, s); | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument("[load_gguf] Input must be a string"); | ||||
| @@ -201,7 +211,7 @@ mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { | ||||
| std::unordered_map<std::string, mx::array> mlx_load_npz_helper( | ||||
|     nb::object file, | ||||
|     mx::StreamOrDevice s) { | ||||
|   bool own_file = nb::isinstance<nb::str>(file); | ||||
|   bool own_file = is_str_or_path(file); | ||||
|  | ||||
|   nb::module_ zipfile = nb::module_::import_("zipfile"); | ||||
|   if (!is_zip_file(zipfile, file)) { | ||||
| @@ -242,8 +252,9 @@ std::unordered_map<std::string, mx::array> mlx_load_npz_helper( | ||||
| } | ||||
|  | ||||
| mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) { | ||||
|   if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string | ||||
|     return mx::load(nb::cast<std::string>(file), s); | ||||
|   if (is_str_or_path(file)) { // Assume .npy file path string | ||||
|     auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|     return mx::load(file_str, s); | ||||
|   } else if (is_istream_object(file)) { | ||||
|     // If we don't own the stream and it was passed to us, eval immediately | ||||
|     auto arr = mx::load(std::make_shared<PyFileReader>(file), s); | ||||
| @@ -264,8 +275,8 @@ LoadOutputTypes mlx_load_helper( | ||||
|     mx::StreamOrDevice s) { | ||||
|   if (!format.has_value()) { | ||||
|     std::string fname; | ||||
|     if (nb::isinstance<nb::str>(file)) { | ||||
|       fname = nb::cast<std::string>(file); | ||||
|     if (is_str_or_path(file)) { | ||||
|       fname = nb::cast<std::string>(nb::str(file)); | ||||
|     } else if (is_istream_object(file)) { | ||||
|       fname = nb::cast<std::string>(file.attr("name")); | ||||
|     } else { | ||||
| @@ -384,8 +395,9 @@ class PyFileWriter : public mx::io::Writer { | ||||
| }; | ||||
|  | ||||
| void mlx_save_helper(nb::object file, mx::array a) { | ||||
|   if (nb::isinstance<nb::str>(file)) { | ||||
|     mx::save(nb::cast<std::string>(file), a); | ||||
|   if (is_str_or_path(file)) { | ||||
|     auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|     mx::save(file_str, a); | ||||
|     return; | ||||
|   } else if (is_ostream_object(file)) { | ||||
|     auto writer = std::make_shared<PyFileWriter>(file); | ||||
| @@ -409,8 +421,8 @@ void mlx_savez_helper( | ||||
|   // Add .npz to the end of the filename if not already there | ||||
|   nb::object file = file_; | ||||
|  | ||||
|   if (nb::isinstance<nb::str>(file_)) { | ||||
|     std::string fname = nb::cast<std::string>(file_); | ||||
|   if (is_str_or_path(file)) { | ||||
|     std::string fname = nb::cast<std::string>(nb::str(file_)); | ||||
|  | ||||
|     // Add .npz to file name if it is not there | ||||
|     if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") | ||||
| @@ -473,11 +485,11 @@ void mlx_save_safetensor_helper( | ||||
|     metadata_map = std::unordered_map<std::string, std::string>(); | ||||
|   } | ||||
|   auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d); | ||||
|   if (nb::isinstance<nb::str>(file)) { | ||||
|   if (is_str_or_path(file)) { | ||||
|     { | ||||
|       auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|       nb::gil_scoped_release nogil; | ||||
|       mx::save_safetensors( | ||||
|           nb::cast<std::string>(file), arrays_map, metadata_map); | ||||
|       mx::save_safetensors(file_str, arrays_map, metadata_map); | ||||
|     } | ||||
|   } else if (is_ostream_object(file)) { | ||||
|     auto writer = std::make_shared<PyFileWriter>(file); | ||||
| @@ -496,19 +508,21 @@ void mlx_save_gguf_helper( | ||||
|     nb::dict a, | ||||
|     std::optional<nb::dict> m) { | ||||
|   auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a); | ||||
|   if (nb::isinstance<nb::str>(file)) { | ||||
|   if (is_str_or_path(file)) { | ||||
|     if (m) { | ||||
|       auto metadata_map = | ||||
|           nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>( | ||||
|               m.value()); | ||||
|       { | ||||
|         auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|         nb::gil_scoped_release nogil; | ||||
|         mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map); | ||||
|         mx::save_gguf(file_str, arrays_map, metadata_map); | ||||
|       } | ||||
|     } else { | ||||
|       { | ||||
|         auto file_str = nb::cast<std::string>(nb::str(file)); | ||||
|         nb::gil_scoped_release nogil; | ||||
|         mx::save_gguf(nb::cast<std::string>(file), arrays_map); | ||||
|         mx::save_gguf(file_str, arrays_map); | ||||
|       } | ||||
|     } | ||||
|   } else { | ||||
|   | ||||
| @@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) { | ||||
|       &mlx_save_helper, | ||||
|       "file"_a, | ||||
|       "arr"_a, | ||||
|       nb::sig("def save(file: str, arr: array) -> None"), | ||||
|       nb::sig( | ||||
|           "def save(file: Union[file, str, pathlib.Path], arr: array) -> None"), | ||||
|       R"pbdoc( | ||||
|         Save the array to a binary file in ``.npy`` format. | ||||
|  | ||||
|         Args: | ||||
|             file (str): File to which the array is saved | ||||
|             file (str, pathlib.Path, file): File to which the array is saved | ||||
|             arr (array): Array to be saved. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
| @@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) { | ||||
|       "file"_a, | ||||
|       "args"_a, | ||||
|       "kwargs"_a, | ||||
|       nb::sig( | ||||
|           "def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"), | ||||
|       R"pbdoc( | ||||
|         Save several arrays to a binary file in uncompressed ``.npz`` | ||||
|         format. | ||||
| @@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) { | ||||
|             mx.savez("model.npz", **dict(flat_params)) | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): Path to file to which the arrays are saved. | ||||
|             file (file, str, pathlib.Path): Path to file to which the arrays are saved. | ||||
|             *args (arrays): Arrays to be saved. | ||||
|             **kwargs (arrays): Arrays to be saved. Each array will be saved | ||||
|               with the associated keyword as the output file name. | ||||
| @@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) { | ||||
|       nb::arg(), | ||||
|       "args"_a, | ||||
|       "kwargs"_a, | ||||
|       nb::sig("def savez_compressed(file: str, *args, **kwargs)"), | ||||
|       nb::sig( | ||||
|           "def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"), | ||||
|       R"pbdoc( | ||||
|         Save several arrays to a binary file in compressed ``.npz`` format. | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): Path to file to which the arrays are saved. | ||||
|             file (file, str, pathlib.Path): Path to file to which the arrays are saved. | ||||
|             *args (arrays): Arrays to be saved. | ||||
|             **kwargs (arrays): Arrays to be saved. Each array will be saved | ||||
|               with the associated keyword as the output file name. | ||||
| @@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) { | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"), | ||||
|           "def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"), | ||||
|       R"pbdoc( | ||||
|         Load array(s) from a binary file. | ||||
|  | ||||
| @@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) { | ||||
|         ``.gguf``. | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): File in which the array is saved. | ||||
|             file (file, str, pathlib.Path): File in which the array is saved. | ||||
|             format (str, optional): Format of the file. If ``None``, the | ||||
|               format is inferred from the file extension. Supported formats: | ||||
|               ``npy``, ``npz``, and ``safetensors``. Default: ``None``. | ||||
| @@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) { | ||||
|       "arrays"_a, | ||||
|       "metadata"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"), | ||||
|           "def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"), | ||||
|       R"pbdoc( | ||||
|         Save array(s) to a binary file in ``.safetensors`` format. | ||||
|  | ||||
| @@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) { | ||||
|         information on the format. | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): File in which the array is saved. | ||||
|             file (file, str, pathlib.Path): File in which the array is saved. | ||||
|             arrays (dict(str, array)): The dictionary of names to arrays to | ||||
|               be saved. | ||||
|             metadata (dict(str, str), optional): The dictionary of | ||||
| @@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) { | ||||
|       "arrays"_a, | ||||
|       "metadata"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"), | ||||
|           "def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"), | ||||
|       R"pbdoc( | ||||
|         Save array(s) to a binary file in ``.gguf`` format. | ||||
|  | ||||
| @@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) { | ||||
|         more information on the format. | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): File in which the array is saved. | ||||
|             file (file, str, pathlib.Path): File in which the array is saved. | ||||
|             arrays (dict(str, array)): The dictionary of names to arrays to | ||||
|               be saved. | ||||
|             metadata (dict(str, Union[array, str, list(str)])): The dictionary | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun