mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -6,12 +6,11 @@ | ||||
| #include <cstring> | ||||
| #include <fstream> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <string_view> | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
|  | ||||
| #include "mlx/load.h" | ||||
| #include "mlx/io/load.h" | ||||
| #include "mlx/ops.h" | ||||
| #include "mlx/utils.h" | ||||
| #include "python/src/load.h" | ||||
| @@ -161,40 +160,68 @@ class PyFileReader : public io::Reader { | ||||
|   py::object tell_func_; | ||||
| }; | ||||
|  | ||||
| DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { | ||||
|   py::module_ zipfile = py::module_::import("zipfile"); | ||||
|  | ||||
|   // Assume .npz file if it is zipped | ||||
|   if (is_zip_file(zipfile, file)) { | ||||
|     // Output dictionary filename in zip -> loaded array | ||||
|     std::unordered_map<std::string, array> array_dict; | ||||
|  | ||||
|     // Create python ZipFile object | ||||
|     ZipFileWrapper zipfile_object(zipfile, file); | ||||
|     for (const std::string& st : zipfile_object.namelist()) { | ||||
|       // Open zip file as a python file stream | ||||
|       py::object sub_file = zipfile_object.open(st); | ||||
|  | ||||
|       // Create array from python fille stream | ||||
|       auto arr = load(std::make_shared<PyFileReader>(sub_file), s); | ||||
|  | ||||
|       // Remove .npy from file if it is there | ||||
|       auto key = st; | ||||
|       if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") | ||||
|         key = st.substr(0, st.length() - 4); | ||||
|  | ||||
|       // Add array to dict | ||||
|       array_dict.insert({key, arr}); | ||||
|     } | ||||
|  | ||||
| std::unordered_map<std::string, array> mlx_load_safetensor_helper( | ||||
|     py::object file, | ||||
|     StreamOrDevice s) { | ||||
|   if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string | ||||
|     return {load_safetensors(py::cast<std::string>(file), s)}; | ||||
|   } else if (is_istream_object(file)) { | ||||
|     // If we don't own the stream and it was passed to us, eval immediately | ||||
|     for (auto& [key, arr] : array_dict) { | ||||
|     auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s); | ||||
|     { | ||||
|       py::gil_scoped_release gil; | ||||
|       arr.eval(); | ||||
|       for (auto& [key, arr] : arr) { | ||||
|         arr.eval(); | ||||
|       } | ||||
|     } | ||||
|     return {arr}; | ||||
|   } | ||||
|  | ||||
|     return {array_dict}; | ||||
|   } else if (py::isinstance<py::str>(file)) { // Assume .npy file path string | ||||
|   throw std::invalid_argument( | ||||
|       "[load_safetensors] Input must be a file-like object, or string"); | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, array> mlx_load_npz_helper( | ||||
|     py::object file, | ||||
|     StreamOrDevice s) { | ||||
|   py::module_ zipfile = py::module_::import("zipfile"); | ||||
|   if (!is_zip_file(zipfile, file)) { | ||||
|     throw std::invalid_argument( | ||||
|         "[load_npz] Input must be a zip file or a file-like object that can be " | ||||
|         "opened with zipfile.ZipFile"); | ||||
|   } | ||||
|   // Output dictionary filename in zip -> loaded array | ||||
|   std::unordered_map<std::string, array> array_dict; | ||||
|  | ||||
|   // Create python ZipFile object | ||||
|   ZipFileWrapper zipfile_object(zipfile, file); | ||||
|   for (const std::string& st : zipfile_object.namelist()) { | ||||
|     // Open zip file as a python file stream | ||||
|     py::object sub_file = zipfile_object.open(st); | ||||
|  | ||||
|     // Create array from python fille stream | ||||
|     auto arr = load(std::make_shared<PyFileReader>(sub_file), s); | ||||
|  | ||||
|     // Remove .npy from file if it is there | ||||
|     auto key = st; | ||||
|     if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") | ||||
|       key = st.substr(0, st.length() - 4); | ||||
|  | ||||
|     // Add array to dict | ||||
|     array_dict.insert({key, arr}); | ||||
|   } | ||||
|  | ||||
|   // If we don't own the stream and it was passed to us, eval immediately | ||||
|   for (auto& [key, arr] : array_dict) { | ||||
|     py::gil_scoped_release gil; | ||||
|     arr.eval(); | ||||
|   } | ||||
|  | ||||
|   return {array_dict}; | ||||
| } | ||||
|  | ||||
| array mlx_load_npy_helper(py::object file, StreamOrDevice s) { | ||||
|   if (py::isinstance<py::str>(file)) { // Assume .npy file path string | ||||
|     return {load(py::cast<std::string>(file), s)}; | ||||
|   } else if (is_istream_object(file)) { | ||||
|     // If we don't own the stream and it was passed to us, eval immediately | ||||
| @@ -205,9 +232,41 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { | ||||
|     } | ||||
|     return {arr}; | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument( | ||||
|       "[load] Input must be a file-like object, string, or pathlib.Path"); | ||||
|       "[load_npy] Input must be a file-like object, or string"); | ||||
| } | ||||
|  | ||||
| DictOrArray mlx_load_helper( | ||||
|     py::object file, | ||||
|     std::optional<std::string> format, | ||||
|     StreamOrDevice s) { | ||||
|   if (!format.has_value()) { | ||||
|     std::string fname; | ||||
|     if (py::isinstance<py::str>(file)) { | ||||
|       fname = py::cast<std::string>(file); | ||||
|     } else if (is_istream_object(file)) { | ||||
|       fname = file.attr("name").cast<std::string>(); | ||||
|     } else { | ||||
|       throw std::invalid_argument( | ||||
|           "[load] Input must be a file-like object, or string"); | ||||
|     } | ||||
|     size_t ext = fname.find_last_of('.'); | ||||
|     if (ext == std::string::npos) { | ||||
|       throw std::invalid_argument( | ||||
|           "[load] Could not infer file format from extension"); | ||||
|     } | ||||
|     format.emplace(fname.substr(ext + 1)); | ||||
|   } | ||||
|  | ||||
|   if (format.value() == "safetensors") { | ||||
|     return mlx_load_safetensor_helper(file, s); | ||||
|   } else if (format.value() == "npz") { | ||||
|     return mlx_load_npz_helper(file, s); | ||||
|   } else if (format.value() == "npy") { | ||||
|     return mlx_load_npy_helper(file, s); | ||||
|   } else { | ||||
|     throw std::invalid_argument("[load] Unknown file format " + format.value()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| @@ -305,7 +364,7 @@ void mlx_save_helper( | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument( | ||||
|       "[save] Input must be a file-like object, string, or pathlib.Path"); | ||||
|       "[save] Input must be a file-like object, or string"); | ||||
| } | ||||
|  | ||||
| void mlx_savez_helper( | ||||
| @@ -361,3 +420,25 @@ void mlx_savez_helper( | ||||
|  | ||||
|   return; | ||||
| } | ||||
|  | ||||
| void mlx_save_safetensor_helper( | ||||
|     py::object file, | ||||
|     py::dict d, | ||||
|     std::optional<bool> retain_graph) { | ||||
|   auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); | ||||
|   if (py::isinstance<py::str>(file)) { | ||||
|     save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph); | ||||
|     return; | ||||
|   } else if (is_ostream_object(file)) { | ||||
|     auto writer = std::make_shared<PyFileWriter>(file); | ||||
|     { | ||||
|       py::gil_scoped_release gil; | ||||
|       save_safetensors(writer, arrays_map, retain_graph); | ||||
|     } | ||||
|  | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument( | ||||
|       "[save_safetensors] Input must be a file-like object, or string"); | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,8 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <optional> | ||||
| #include <string> | ||||
| #include <unordered_map> | ||||
| #include <variant> | ||||
| #include "mlx/ops.h" | ||||
| @@ -12,7 +14,18 @@ using namespace mlx::core; | ||||
|  | ||||
| using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>; | ||||
|  | ||||
| DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); | ||||
| std::unordered_map<std::string, array> mlx_load_safetensor_helper( | ||||
|     py::object file, | ||||
|     StreamOrDevice s); | ||||
| void mlx_save_safetensor_helper( | ||||
|     py::object file, | ||||
|     py::dict d, | ||||
|     std::optional<bool> retain_graph = std::nullopt); | ||||
|  | ||||
| DictOrArray mlx_load_helper( | ||||
|     py::object file, | ||||
|     std::optional<std::string> format, | ||||
|     StreamOrDevice s); | ||||
| void mlx_save_helper( | ||||
|     py::object file, | ||||
|     array a, | ||||
|   | ||||
| @@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) { | ||||
|         Args: | ||||
|             file (str): File to which the array is saved | ||||
|             arr (array): Array to be saved. | ||||
|             retain_graph (bool, optional): Optional argument to retain graph | ||||
|               during array evaluation before saving. If not provided the graph | ||||
|               is retained if we are during a function transformation. Default: | ||||
|               None | ||||
|  | ||||
|             retain_graph (bool, optional): Whether or not to retain the graph | ||||
|               during array evaluation. If left unspecified the graph is retained | ||||
|               only if saving is done in a function transformation. Default: ``None`` | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "savez", | ||||
| @@ -2932,18 +2930,45 @@ void init_ops(py::module_& m) { | ||||
|       &mlx_load_helper, | ||||
|       "file"_a, | ||||
|       py::pos_only(), | ||||
|       "format"_a = none, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] | ||||
|         load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] | ||||
|  | ||||
|         Load array(s) from a binary file in ``.npy`` or ``.npz`` format. | ||||
|         Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format. | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): File in which the array is saved | ||||
|  | ||||
|             file (file, str): File in which the array is saved. | ||||
|             format (str, optional): Format of the file. If ``None``, the format | ||||
|               is inferred from the file extension. Supported formats: ``npy``, | ||||
|               ``npz``, and ``safetensors``. Default: ``None``. | ||||
|         Returns: | ||||
|             result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file | ||||
|             result (array, dict): | ||||
|                 A single array if loading from a ``.npy`` file or a dict mapping | ||||
|                 names to arrays if loading from a ``.npz`` or ``.safetensors`` file. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "save_safetensors", | ||||
|       &mlx_save_safetensor_helper, | ||||
|       "file"_a, | ||||
|       "arrays"_a, | ||||
|       py::pos_only(), | ||||
|       "retain_graph"_a = std::nullopt, | ||||
|       py::kw_only(), | ||||
|       R"pbdoc( | ||||
|         save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None) | ||||
|  | ||||
|         Save array(s) to a binary file in ``.safetensors`` format. | ||||
|  | ||||
|         For more information on the format see https://huggingface.co/docs/safetensors/index. | ||||
|  | ||||
|         Args: | ||||
|             file (file, str): File in which the array is saved> | ||||
|             arrays (dict(str, array)): The dictionary of names to arrays to be saved. | ||||
|             retain_graph (bool, optional): Whether or not to retain the graph | ||||
|               during array evaluation. If left unspecified the graph is retained | ||||
|               only if saving is done in a function transformation. Default: ``None``. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "where", | ||||
|   | ||||
| @@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|                         load_arr_mlx_npy = np.load(save_file_mlx) | ||||
|                         self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) | ||||
|  | ||||
|     def test_save_and_load_safetensors(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         for dt in self.dtypes + ["bfloat16"]: | ||||
|             with self.subTest(dtype=dt): | ||||
|                 for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): | ||||
|                     with self.subTest(shape=shape): | ||||
|                         save_file_mlx = os.path.join( | ||||
|                             self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" | ||||
|                         ) | ||||
|                         save_dict = { | ||||
|                             "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) | ||||
|                             if dt in ["float32", "float16", "bfloat16"] | ||||
|                             else mx.ones(shape, dtype=getattr(mx, dt)) | ||||
|                         } | ||||
|  | ||||
|                         with open(save_file_mlx, "wb") as f: | ||||
|                             mx.save_safetensors(f, save_dict) | ||||
|                         with open(save_file_mlx, "rb") as f: | ||||
|                             load_dict = mx.load(f) | ||||
|  | ||||
|                         self.assertTrue("test" in load_dict) | ||||
|                         self.assertTrue( | ||||
|                             mx.array_equal(load_dict["test"], save_dict["test"]) | ||||
|                         ) | ||||
|  | ||||
|     def test_save_and_load_fs(self): | ||||
|  | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo