allow pathlib.Path to save/load functions

This commit is contained in:
Awni Hannun 2025-08-25 13:52:12 -07:00
parent ac85ddfdb7
commit 066d77244c
2 changed files with 48 additions and 30 deletions

View File

@ -23,6 +23,14 @@ using namespace nb::literals;
// Helpers
///////////////////////////////////////////////////////////////////////////////
bool is_str_or_path(nb::object obj) {
if (nb::isinstance<nb::str>(obj)) {
return true;
}
nb::object path_type = nb::module_::import_("pathlib").attr("Path");
return nb::isinstance(obj, path_type);
}
bool is_istream_object(const nb::object& file) {
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
@ -172,8 +180,9 @@ std::pair<
std::unordered_map<std::string, mx::array>,
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return mx::load_safetensors(nb::cast<std::string>(file), s);
if (is_str_or_path(file)) { // Assume .safetensors file path string
auto file_str = nb::cast<std::string>(nb::str(file));
return mx::load_safetensors(file_str, s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
@ -191,8 +200,9 @@ mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
}
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return mx::load_gguf(nb::cast<std::string>(file), s);
if (is_str_or_path(file)) { // Assume .gguf file path string
auto file_str = nb::cast<std::string>(nb::str(file));
return mx::load_gguf(file_str, s);
}
throw std::invalid_argument("[load_gguf] Input must be a string");
@ -201,7 +211,7 @@ mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
nb::object file,
mx::StreamOrDevice s) {
bool own_file = nb::isinstance<nb::str>(file);
bool own_file = is_str_or_path(file);
nb::module_ zipfile = nb::module_::import_("zipfile");
if (!is_zip_file(zipfile, file)) {
@ -242,8 +252,9 @@ std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
}
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return mx::load(nb::cast<std::string>(file), s);
if (is_str_or_path(file)) { // Assume .npy file path string
auto file_str = nb::cast<std::string>(nb::str(file));
return mx::load(file_str, s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
@ -264,8 +275,8 @@ LoadOutputTypes mlx_load_helper(
mx::StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (nb::isinstance<nb::str>(file)) {
fname = nb::cast<std::string>(file);
if (is_str_or_path(file)) {
fname = nb::cast<std::string>(nb::str(file));
} else if (is_istream_object(file)) {
fname = nb::cast<std::string>(file.attr("name"));
} else {
@ -384,8 +395,9 @@ class PyFileWriter : public mx::io::Writer {
};
void mlx_save_helper(nb::object file, mx::array a) {
if (nb::isinstance<nb::str>(file)) {
mx::save(nb::cast<std::string>(file), a);
if (is_str_or_path(file)) {
auto file_str = nb::cast<std::string>(nb::str(file));
mx::save(file_str, a);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
@ -409,8 +421,8 @@ void mlx_savez_helper(
// Add .npz to the end of the filename if not already there
nb::object file = file_;
if (nb::isinstance<nb::str>(file_)) {
std::string fname = nb::cast<std::string>(file_);
if (is_str_or_path(file)) {
std::string fname = nb::cast<std::string>(nb::str(file_));
// Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
@ -473,11 +485,11 @@ void mlx_save_safetensor_helper(
metadata_map = std::unordered_map<std::string, std::string>();
}
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
if (nb::isinstance<nb::str>(file)) {
if (is_str_or_path(file)) {
{
auto file_str = nb::cast<std::string>(nb::str(file));
nb::gil_scoped_release nogil;
mx::save_safetensors(
nb::cast<std::string>(file), arrays_map, metadata_map);
mx::save_safetensors(file_str, arrays_map, metadata_map);
}
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
@ -496,19 +508,21 @@ void mlx_save_gguf_helper(
nb::dict a,
std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
if (nb::isinstance<nb::str>(file)) {
if (is_str_or_path(file)) {
if (m) {
auto metadata_map =
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
m.value());
{
auto file_str = nb::cast<std::string>(nb::str(file));
nb::gil_scoped_release nogil;
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
mx::save_gguf(file_str, arrays_map, metadata_map);
}
} else {
{
auto file_str = nb::cast<std::string>(nb::str(file));
nb::gil_scoped_release nogil;
mx::save_gguf(nb::cast<std::string>(file), arrays_map);
mx::save_gguf(file_str, arrays_map);
}
}
} else {

View File

@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) {
&mlx_save_helper,
"file"_a,
"arr"_a,
nb::sig("def save(file: str, arr: array) -> None"),
nb::sig(
"def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
R"pbdoc(
Save the array to a binary file in ``.npy`` format.
Args:
file (str): File to which the array is saved
file (str, pathlib.Path, file): File to which the array is saved
arr (array): Array to be saved.
)pbdoc");
m.def(
@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) {
"file"_a,
"args"_a,
"kwargs"_a,
nb::sig(
"def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
R"pbdoc(
Save several arrays to a binary file in uncompressed ``.npz``
format.
@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) {
mx.savez("model.npz", **dict(flat_params))
Args:
file (file, str): Path to file to which the arrays are saved.
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) {
nb::arg(),
"args"_a,
"kwargs"_a,
nb::sig("def savez_compressed(file: str, *args, **kwargs)"),
nb::sig(
"def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
R"pbdoc(
Save several arrays to a binary file in compressed ``.npz`` format.
Args:
file (file, str): Path to file to which the arrays are saved.
file (file, str, pathlib.Path): Path to file to which the arrays are saved.
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
R"pbdoc(
Load array(s) from a binary file.
@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) {
``.gguf``.
Args:
file (file, str): File in which the array is saved.
file (file, str, pathlib.Path): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the
format is inferred from the file extension. Supported formats:
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) {
"arrays"_a,
"metadata"_a = nb::none(),
nb::sig(
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
"def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
R"pbdoc(
Save array(s) to a binary file in ``.safetensors`` format.
@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) {
information on the format.
Args:
file (file, str): File in which the array is saved.
file (file, str, pathlib.Path): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to
be saved.
metadata (dict(str, str), optional): The dictionary of
@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) {
"arrays"_a,
"metadata"_a = nb::none(),
nb::sig(
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
"def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
R"pbdoc(
Save array(s) to a binary file in ``.gguf`` format.
@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) {
more information on the format.
Args:
file (file, str): File in which the array is saved.
file (file, str, pathlib.Path): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to
be saved.
metadata (dict(str, Union[array, str, list(str)])): The dictionary