mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
updated load api to take format as an option else infer from extension
This commit is contained in:
parent
c81f5a5b94
commit
dc5abdc4c4
@ -6,7 +6,6 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -179,14 +178,18 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[load] Input must be a file-like object, string, or pathlib.Path");
|
"[load_safetensor] Input must be a file-like object, string, or pathlib.Path");
|
||||||
}
|
}
|
||||||
|
|
||||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||||
|
py::object file,
|
||||||
|
StreamOrDevice s) {
|
||||||
py::module_ zipfile = py::module_::import("zipfile");
|
py::module_ zipfile = py::module_::import("zipfile");
|
||||||
|
if (!is_zip_file(zipfile, file)) {
|
||||||
// Assume .npz file if it is zipped
|
throw std::invalid_argument(
|
||||||
if (is_zip_file(zipfile, file)) {
|
"[load_npz] Input must be a zip file or a file-like object that can be "
|
||||||
|
"opened with zipfile.ZipFile");
|
||||||
|
}
|
||||||
// Output dictionary filename in zip -> loaded array
|
// Output dictionary filename in zip -> loaded array
|
||||||
std::unordered_map<std::string, array> array_dict;
|
std::unordered_map<std::string, array> array_dict;
|
||||||
|
|
||||||
@ -215,7 +218,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {array_dict};
|
return {array_dict};
|
||||||
} else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
}
|
||||||
|
|
||||||
|
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||||
|
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
||||||
return {load(py::cast<std::string>(file), s)};
|
return {load(py::cast<std::string>(file), s)};
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
@ -226,9 +232,39 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
|
|||||||
}
|
}
|
||||||
return {arr};
|
return {arr};
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point behind the Python-level `load`: dispatches to the safetensors,
// npz, or npy loader.
//
// Args:
//   file:   a Python `str` path, or an object accepted by is_istream_object().
//   format: "safetensors", "npz", or "npy". When empty, the format is taken
//           from the extension of the path (for a `str`) or of the stream's
//           `name` attribute (for a file-like object).
//   s:      stream or device forwarded unchanged to the selected loader.
//
// Returns:
//   Whatever the selected loader returns, wrapped in DictOrArray.
//
// Throws:
//   std::invalid_argument if the input type is unsupported, if the format
//   cannot be inferred, or if the format is not one of the three above.
DictOrArray mlx_load_helper(
    py::object file,
    std::optional<std::string> format,
    StreamOrDevice s) {
  if (!format.has_value()) {
    std::string fname;
    if (py::isinstance<py::str>(file)) {
      fname = py::cast<std::string>(file);
    } else if (is_istream_object(file)) {
      // In-memory streams (e.g. io.BytesIO) carry no `name`, and files opened
      // from a descriptor report an integer one. Reject both with the same
      // exception type as the other inference failures instead of letting a
      // raw Python attribute/cast error escape.
      if (!py::hasattr(file, "name") ||
          !py::isinstance<py::str>(file.attr("name"))) {
        throw std::invalid_argument(
            "[load] Could not infer file format from a file-like object "
            "without a string name; pass the format explicitly");
      }
      fname = file.attr("name").cast<std::string>();
    } else {
      // NOTE(review): the message mentions pathlib.Path, but only `str` and
      // stream objects are accepted by the checks above -- confirm whether
      // Path inputs are converted before reaching this function.
      throw std::invalid_argument(
          "[load] Input must be a file-like object, string, or pathlib.Path");
    }

    // Only a dot inside the final path component marks an extension; a dot in
    // a directory name (e.g. "run.v2/weights") must not be mistaken for one.
    const size_t sep = fname.find_last_of("/\\");
    const size_t ext = fname.find_last_of('.');
    if (ext == std::string::npos ||
        (sep != std::string::npos && ext < sep)) {
      throw std::invalid_argument(
          "[load] Could not infer file format from extension");
    }
    format.emplace(fname.substr(ext + 1));
  }

  const std::string& fmt = *format;
  if (fmt == "safetensors") {
    return mlx_load_safetensor_helper(file, s);
  }
  if (fmt == "npz") {
    return mlx_load_npz_helper(file, s);
  }
  if (fmt == "npy") {
    return mlx_load_npy_helper(file, s);
  }
  throw std::invalid_argument("[load] Unknown file format " + fmt);
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
@ -2931,38 +2931,22 @@ void init_ops(py::module_& m) {
|
|||||||
"load",
|
"load",
|
||||||
&mlx_load_helper,
|
&mlx_load_helper,
|
||||||
"file"_a,
|
"file"_a,
|
||||||
|
"format"_a = none,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||||
|
|
||||||
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
|
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved
|
file (file, str): File in which the array is saved
|
||||||
|
format (str, optional): Format of the file. If ``None``, the format
|
||||||
|
is inferred from the file extension. Supported formats: npy, npz, and safetensors. (default: ``None``)
|
||||||
Returns:
|
Returns:
|
||||||
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
|
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
|
||||||
"load_safetensor",
|
|
||||||
&mlx_load_safetensor_helper,
|
|
||||||
"file"_a,
|
|
||||||
py::pos_only(),
|
|
||||||
py::kw_only(),
|
|
||||||
"stream"_a = none,
|
|
||||||
R"pbdoc(
|
|
||||||
load_safetensor(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Dict[str, array]
|
|
||||||
|
|
||||||
Load array(s) from a binary file in ``.safetensors`` format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (file, str): File in which the array is saved
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
result dict: The loaded dict mapping name to array from the ``.safetensors`` file
|
|
||||||
)pbdoc");
|
|
||||||
m.def(
|
m.def(
|
||||||
"save_safetensor",
|
"save_safetensor",
|
||||||
&mlx_save_safetensor_helper,
|
&mlx_save_safetensor_helper,
|
||||||
|
Loading…
Reference in New Issue
Block a user