diff --git a/python/src/load.cpp b/python/src/load.cpp index f859febe3..1246a94b2 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -179,43 +178,50 @@ std::unordered_map mlx_load_safetensor_helper( } throw std::invalid_argument( - "[load] Input must be a file-like object, string, or pathlib.Path"); + "[load_safetensor] Input must be a file-like object, string, or pathlib.Path"); } -DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { +std::unordered_map mlx_load_npz_helper( + py::object file, + StreamOrDevice s) { py::module_ zipfile = py::module_::import("zipfile"); + if (!is_zip_file(zipfile, file)) { + throw std::invalid_argument( + "[load_npz] Input must be a zip file or a file-like object that can be " + "opened with zipfile.ZipFile"); + } + // Output dictionary filename in zip -> loaded array + std::unordered_map array_dict; - // Assume .npz file if it is zipped - if (is_zip_file(zipfile, file)) { - // Output dictionary filename in zip -> loaded array - std::unordered_map array_dict; + // Create python ZipFile object + ZipFileWrapper zipfile_object(zipfile, file); + for (const std::string& st : zipfile_object.namelist()) { + // Open zip file as a python file stream + py::object sub_file = zipfile_object.open(st); - // Create python ZipFile object - ZipFileWrapper zipfile_object(zipfile, file); - for (const std::string& st : zipfile_object.namelist()) { - // Open zip file as a python file stream - py::object sub_file = zipfile_object.open(st); + // Create array from python fille stream + auto arr = load(std::make_shared(sub_file), s); - // Create array from python fille stream - auto arr = load(std::make_shared(sub_file), s); + // Remove .npy from file if it is there + auto key = st; + if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") + key = st.substr(0, st.length() - 4); - // Remove .npy from file if it is there - auto key = st; - if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") - key = st.substr(0, st.length() - 4); + // Add array to dict + array_dict.insert({key, arr}); + } - // Add array to dict - array_dict.insert({key, arr}); - } + // If we don't own the stream and it was passed to us, eval immediately + for (auto& [key, arr] : array_dict) { + py::gil_scoped_release gil; + arr.eval(); + } - // If we don't own the stream and it was passed to us, eval immediately - for (auto& [key, arr] : array_dict) { - py::gil_scoped_release gil; - arr.eval(); - } + return {array_dict}; +} - return {array_dict}; - } else if (py::isinstance(file)) { // Assume .npy file path string +array mlx_load_npy_helper(py::object file, StreamOrDevice s) { + if (py::isinstance(file)) { // Assume .npy file path string return {load(py::cast(file), s)}; } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately @@ -226,9 +232,39 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { } return {arr}; } +} - throw std::invalid_argument( - "[load] Input must be a file-like object, string, or pathlib.Path"); +DictOrArray mlx_load_helper( + py::object file, + std::optional format, + StreamOrDevice s) { + if (!format.has_value()) { + std::string fname; + if (py::isinstance(file)) { + fname = py::cast(file); + } else if (is_istream_object(file)) { + fname = file.attr("name").cast(); + } else { + throw std::invalid_argument( + "[load] Input must be a file-like object, string, or pathlib.Path"); + } + size_t ext = fname.find_last_of('.'); + if (ext == std::string::npos) { + throw std::invalid_argument( + "[load] Could not infer file format from extension"); + } + format.emplace(fname.substr(ext + 1)); + } + + if (format.value() == "safetensors") { + return mlx_load_safetensor_helper(file, s); + } else if (format.value() == "npz") { + return mlx_load_npz_helper(file, s); + } else if (format.value() == "npy") { + return mlx_load_npy_helper(file, s); + } else { + throw std::invalid_argument("[load] Unknown file format " + format.value()); + } } /////////////////////////////////////////////////////////////////////////////// diff --git a/python/src/load.h b/python/src/load.h index 1ced6e35d..45d64664d 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include #include #include "mlx/ops.h" diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2cf8b70dd..6455cfc22 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2931,38 +2931,22 @@ void init_ops(py::module_& m) { "load", &mlx_load_helper, "file"_a, + "format"_a = none, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] + load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] Load array(s) from a binary file in ``.npy`` or ``.npz`` format. Args: file (file, str): File in which the array is saved - + format (str, optional): Format of the file. If ``None``, the format + is inferred from the file extension. Supported formats: npy, npz, and safetensors. (default: ``None``) Returns: result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file )pbdoc"); - m.def( - "load_safetensor", - &mlx_load_safetensor_helper, - "file"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, - R"pbdoc( - load_safetensor(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Dict[str, array] - - Load array(s) from a binary file in ``.safetensors`` format. - - Args: - file (file, str): File in which the array is saved - - Returns: - result dict: The loaded dict mapping name to array from the ``.safetensors`` file - )pbdoc"); m.def( "save_safetensor", &mlx_save_safetensor_helper,