updated load api to take format as an option else infer from extension

This commit is contained in:
dc-dc-dc 2023-12-20 12:53:53 -05:00
parent c81f5a5b94
commit dc5abdc4c4
3 changed files with 72 additions and 50 deletions

View File

@ -6,7 +6,6 @@
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
@ -179,43 +178,50 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
}
throw std::invalid_argument(
"[load] Input must be a file-like object, string, or pathlib.Path");
"[load_safetensor] Input must be a file-like object, string, or pathlib.Path");
}
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file,
StreamOrDevice s) {
py::module_ zipfile = py::module_::import("zipfile");
if (!is_zip_file(zipfile, file)) {
throw std::invalid_argument(
"[load_npz] Input must be a zip file or a file-like object that can be "
"opened with zipfile.ZipFile");
}
// Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict;
// Assume .npz file if it is zipped
if (is_zip_file(zipfile, file)) {
// Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict;
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
// Create array from python fille stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
// Create array from python fille stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there
auto key = st;
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
key = st.substr(0, st.length() - 4);
// Remove .npy from file if it is there
auto key = st;
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
key = st.substr(0, st.length() - 4);
// Add array to dict
array_dict.insert({key, arr});
}
// Add array to dict
array_dict.insert({key, arr});
}
// If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) {
py::gil_scoped_release gil;
arr.eval();
}
// If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) {
py::gil_scoped_release gil;
arr.eval();
}
return {array_dict};
}
return {array_dict};
} else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return {load(py::cast<std::string>(file), s)};
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
@ -226,9 +232,39 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
}
return {arr};
}
}
throw std::invalid_argument(
"[load] Input must be a file-like object, string, or pathlib.Path");
DictOrArray mlx_load_helper(
py::object file,
std::optional<std::string> format,
StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (py::isinstance<py::str>(file)) {
fname = py::cast<std::string>(file);
} else if (is_istream_object(file)) {
fname = file.attr("name").cast<std::string>();
} else {
throw std::invalid_argument(
"[load] Input must be a file-like object, string, or pathlib.Path");
}
size_t ext = fname.find_last_of('.');
if (ext == std::string::npos) {
throw std::invalid_argument(
"[load] Could not infer file format from extension");
}
format.emplace(fname.substr(ext + 1));
}
if (format.value() == "safetensors") {
return mlx_load_safetensor_helper(file, s);
} else if (format.value() == "npz") {
return mlx_load_npz_helper(file, s);
} else if (format.value() == "npy") {
return mlx_load_npy_helper(file, s);
} else {
throw std::invalid_argument("[load] Unknown file format " + format.value());
}
}
///////////////////////////////////////////////////////////////////////////////

View File

@ -3,6 +3,8 @@
#pragma once
#include <pybind11/pybind11.h>
#include <optional>
#include <string>
#include <unordered_map>
#include <variant>
#include "mlx/ops.h"

View File

@ -2931,38 +2931,22 @@ void init_ops(py::module_& m) {
"load",
&mlx_load_helper,
"file"_a,
"format"_a = none,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
Args:
file (file, str): File in which the array is saved
format (str, optional): Format of the file. If ``None``, the format
is inferred from the file extension. Supported formats: npy, npz, and safetensors. (default: ``None``)
Returns:
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
)pbdoc");
m.def(
"load_safetensor",
&mlx_load_safetensor_helper,
"file"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
load_safetensor(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Dict[str, array]
Load array(s) from a binary file in ``.safetensors`` format.
Args:
file (file, str): File in which the array is saved
Returns:
result dict: The loaded dict mapping name to array from the ``.safetensors`` file
)pbdoc");
m.def(
"save_safetensor",
&mlx_save_safetensor_helper,