mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
allow pathlib.Path to save/load functions (#2541)
This commit is contained in:
parent
d2f540f4e0
commit
db14e29a0b
@ -23,6 +23,14 @@ using namespace nb::literals;
|
||||
// Helpers
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool is_str_or_path(nb::object obj) {
|
||||
if (nb::isinstance<nb::str>(obj)) {
|
||||
return true;
|
||||
}
|
||||
nb::object path_type = nb::module_::import_("pathlib").attr("Path");
|
||||
return nb::isinstance(obj, path_type);
|
||||
}
|
||||
|
||||
bool is_istream_object(const nb::object& file) {
|
||||
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
|
||||
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
|
||||
@ -172,8 +180,9 @@ std::pair<
|
||||
std::unordered_map<std::string, mx::array>,
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return mx::load_safetensors(nb::cast<std::string>(file), s);
|
||||
if (is_str_or_path(file)) { // Assume .safetensors file path string
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
return mx::load_safetensors(file_str, s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
@ -191,8 +200,9 @@ mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
}
|
||||
|
||||
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return mx::load_gguf(nb::cast<std::string>(file), s);
|
||||
if (is_str_or_path(file)) { // Assume .gguf file path string
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
return mx::load_gguf(file_str, s);
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
@ -201,7 +211,7 @@ mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
|
||||
nb::object file,
|
||||
mx::StreamOrDevice s) {
|
||||
bool own_file = nb::isinstance<nb::str>(file);
|
||||
bool own_file = is_str_or_path(file);
|
||||
|
||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||
if (!is_zip_file(zipfile, file)) {
|
||||
@ -242,8 +252,9 @@ std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
|
||||
}
|
||||
|
||||
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
|
||||
return mx::load(nb::cast<std::string>(file), s);
|
||||
if (is_str_or_path(file)) { // Assume .npy file path string
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
return mx::load(file_str, s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
|
||||
@ -264,8 +275,8 @@ LoadOutputTypes mlx_load_helper(
|
||||
mx::StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
fname = nb::cast<std::string>(file);
|
||||
if (is_str_or_path(file)) {
|
||||
fname = nb::cast<std::string>(nb::str(file));
|
||||
} else if (is_istream_object(file)) {
|
||||
fname = nb::cast<std::string>(file.attr("name"));
|
||||
} else {
|
||||
@ -384,8 +395,9 @@ class PyFileWriter : public mx::io::Writer {
|
||||
};
|
||||
|
||||
void mlx_save_helper(nb::object file, mx::array a) {
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
mx::save(nb::cast<std::string>(file), a);
|
||||
if (is_str_or_path(file)) {
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
mx::save(file_str, a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
@ -409,8 +421,8 @@ void mlx_savez_helper(
|
||||
// Add .npz to the end of the filename if not already there
|
||||
nb::object file = file_;
|
||||
|
||||
if (nb::isinstance<nb::str>(file_)) {
|
||||
std::string fname = nb::cast<std::string>(file_);
|
||||
if (is_str_or_path(file)) {
|
||||
std::string fname = nb::cast<std::string>(nb::str(file_));
|
||||
|
||||
// Add .npz to file name if it is not there
|
||||
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
|
||||
@ -473,11 +485,11 @@ void mlx_save_safetensor_helper(
|
||||
metadata_map = std::unordered_map<std::string, std::string>();
|
||||
}
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
if (is_str_or_path(file)) {
|
||||
{
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
nb::gil_scoped_release nogil;
|
||||
mx::save_safetensors(
|
||||
nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
mx::save_safetensors(file_str, arrays_map, metadata_map);
|
||||
}
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
@ -496,19 +508,21 @@ void mlx_save_gguf_helper(
|
||||
nb::dict a,
|
||||
std::optional<nb::dict> m) {
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
if (is_str_or_path(file)) {
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
|
||||
m.value());
|
||||
{
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
nb::gil_scoped_release nogil;
|
||||
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
mx::save_gguf(file_str, arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
{
|
||||
auto file_str = nb::cast<std::string>(nb::str(file));
|
||||
nb::gil_scoped_release nogil;
|
||||
mx::save_gguf(nb::cast<std::string>(file), arrays_map);
|
||||
mx::save_gguf(file_str, arrays_map);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) {
|
||||
&mlx_save_helper,
|
||||
"file"_a,
|
||||
"arr"_a,
|
||||
nb::sig("def save(file: str, arr: array) -> None"),
|
||||
nb::sig(
|
||||
"def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
|
||||
R"pbdoc(
|
||||
Save the array to a binary file in ``.npy`` format.
|
||||
|
||||
Args:
|
||||
file (str): File to which the array is saved
|
||||
file (str, pathlib.Path, file): File to which the array is saved
|
||||
arr (array): Array to be saved.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) {
|
||||
"file"_a,
|
||||
"args"_a,
|
||||
"kwargs"_a,
|
||||
nb::sig(
|
||||
"def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
|
||||
R"pbdoc(
|
||||
Save several arrays to a binary file in uncompressed ``.npz``
|
||||
format.
|
||||
@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) {
|
||||
mx.savez("model.npz", **dict(flat_params))
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) {
|
||||
nb::arg(),
|
||||
"args"_a,
|
||||
"kwargs"_a,
|
||||
nb::sig("def savez_compressed(file: str, *args, **kwargs)"),
|
||||
nb::sig(
|
||||
"def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
|
||||
R"pbdoc(
|
||||
Save several arrays to a binary file in compressed ``.npz`` format.
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
||||
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
||||
R"pbdoc(
|
||||
Load array(s) from a binary file.
|
||||
|
||||
@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) {
|
||||
``.gguf``.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
file (file, str, pathlib.Path): File in which the array is saved.
|
||||
format (str, optional): Format of the file. If ``None``, the
|
||||
format is inferred from the file extension. Supported formats:
|
||||
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
|
||||
@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) {
|
||||
"arrays"_a,
|
||||
"metadata"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
|
||||
"def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
|
||||
R"pbdoc(
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) {
|
||||
information on the format.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
file (file, str, pathlib.Path): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||
be saved.
|
||||
metadata (dict(str, str), optional): The dictionary of
|
||||
@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) {
|
||||
"arrays"_a,
|
||||
"metadata"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
|
||||
"def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
|
||||
R"pbdoc(
|
||||
Save array(s) to a binary file in ``.gguf`` format.
|
||||
|
||||
@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) {
|
||||
more information on the format.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
file (file, str, pathlib.Path): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to
|
||||
be saved.
|
||||
metadata (dict(str, Union[array, str, list(str)])): The dictionary
|
||||
|
Loading…
Reference in New Issue
Block a user