Allow pathlib.Path arguments in the save/load functions

This commit is contained in:
Awni Hannun 2025-08-25 13:52:12 -07:00
parent ac85ddfdb7
commit 066d77244c
2 changed files with 48 additions and 30 deletions

View File

@ -23,6 +23,14 @@ using namespace nb::literals;
// Helpers // Helpers
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
bool is_str_or_path(nb::object obj) {
if (nb::isinstance<nb::str>(obj)) {
return true;
}
nb::object path_type = nb::module_::import_("pathlib").attr("Path");
return nb::isinstance(obj, path_type);
}
bool is_istream_object(const nb::object& file) { bool is_istream_object(const nb::object& file) {
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") && return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
@ -172,8 +180,9 @@ std::pair<
std::unordered_map<std::string, mx::array>, std::unordered_map<std::string, mx::array>,
std::unordered_map<std::string, std::string>> std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string if (is_str_or_path(file)) { // Assume .safetensors file path string
return mx::load_safetensors(nb::cast<std::string>(file), s); auto file_str = nb::cast<std::string>(nb::str(file));
return mx::load_safetensors(file_str, s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s); auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
@ -191,8 +200,9 @@ mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
} }
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string if (is_str_or_path(file)) { // Assume .gguf file path string
return mx::load_gguf(nb::cast<std::string>(file), s); auto file_str = nb::cast<std::string>(nb::str(file));
return mx::load_gguf(file_str, s);
} }
throw std::invalid_argument("[load_gguf] Input must be a string"); throw std::invalid_argument("[load_gguf] Input must be a string");
@ -201,7 +211,7 @@ mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
std::unordered_map<std::string, mx::array> mlx_load_npz_helper( std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
nb::object file, nb::object file,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
bool own_file = nb::isinstance<nb::str>(file); bool own_file = is_str_or_path(file);
nb::module_ zipfile = nb::module_::import_("zipfile"); nb::module_ zipfile = nb::module_::import_("zipfile");
if (!is_zip_file(zipfile, file)) { if (!is_zip_file(zipfile, file)) {
@ -242,8 +252,9 @@ std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
} }
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) { mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string if (is_str_or_path(file)) { // Assume .npy file path string
return mx::load(nb::cast<std::string>(file), s); auto file_str = nb::cast<std::string>(nb::str(file));
return mx::load(file_str, s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto arr = mx::load(std::make_shared<PyFileReader>(file), s); auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
@ -264,8 +275,8 @@ LoadOutputTypes mlx_load_helper(
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (!format.has_value()) { if (!format.has_value()) {
std::string fname; std::string fname;
if (nb::isinstance<nb::str>(file)) { if (is_str_or_path(file)) {
fname = nb::cast<std::string>(file); fname = nb::cast<std::string>(nb::str(file));
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
fname = nb::cast<std::string>(file.attr("name")); fname = nb::cast<std::string>(file.attr("name"));
} else { } else {
@ -384,8 +395,9 @@ class PyFileWriter : public mx::io::Writer {
}; };
void mlx_save_helper(nb::object file, mx::array a) { void mlx_save_helper(nb::object file, mx::array a) {
if (nb::isinstance<nb::str>(file)) { if (is_str_or_path(file)) {
mx::save(nb::cast<std::string>(file), a); auto file_str = nb::cast<std::string>(nb::str(file));
mx::save(file_str, a);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
@ -409,8 +421,8 @@ void mlx_savez_helper(
// Add .npz to the end of the filename if not already there // Add .npz to the end of the filename if not already there
nb::object file = file_; nb::object file = file_;
if (nb::isinstance<nb::str>(file_)) { if (is_str_or_path(file)) {
std::string fname = nb::cast<std::string>(file_); std::string fname = nb::cast<std::string>(nb::str(file_));
// Add .npz to file name if it is not there // Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
@ -473,11 +485,11 @@ void mlx_save_safetensor_helper(
metadata_map = std::unordered_map<std::string, std::string>(); metadata_map = std::unordered_map<std::string, std::string>();
} }
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d); auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
if (nb::isinstance<nb::str>(file)) { if (is_str_or_path(file)) {
{ {
auto file_str = nb::cast<std::string>(nb::str(file));
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
mx::save_safetensors( mx::save_safetensors(file_str, arrays_map, metadata_map);
nb::cast<std::string>(file), arrays_map, metadata_map);
} }
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
@ -496,19 +508,21 @@ void mlx_save_gguf_helper(
nb::dict a, nb::dict a,
std::optional<nb::dict> m) { std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a); auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
if (nb::isinstance<nb::str>(file)) { if (is_str_or_path(file)) {
if (m) { if (m) {
auto metadata_map = auto metadata_map =
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>( nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
m.value()); m.value());
{ {
auto file_str = nb::cast<std::string>(nb::str(file));
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map); mx::save_gguf(file_str, arrays_map, metadata_map);
} }
} else { } else {
{ {
auto file_str = nb::cast<std::string>(nb::str(file));
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
mx::save_gguf(nb::cast<std::string>(file), arrays_map); mx::save_gguf(file_str, arrays_map);
} }
} }
} else { } else {

View File

@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) {
&mlx_save_helper, &mlx_save_helper,
"file"_a, "file"_a,
"arr"_a, "arr"_a,
nb::sig("def save(file: str, arr: array) -> None"), nb::sig(
"def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
R"pbdoc( R"pbdoc(
Save the array to a binary file in ``.npy`` format. Save the array to a binary file in ``.npy`` format.
Args: Args:
file (str): File to which the array is saved file (str, pathlib.Path, file): File to which the array is saved
arr (array): Array to be saved. arr (array): Array to be saved.
)pbdoc"); )pbdoc");
m.def( m.def(
@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) {
"file"_a, "file"_a,
"args"_a, "args"_a,
"kwargs"_a, "kwargs"_a,
nb::sig(
"def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
R"pbdoc( R"pbdoc(
Save several arrays to a binary file in uncompressed ``.npz`` Save several arrays to a binary file in uncompressed ``.npz``
format. format.
@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) {
mx.savez("model.npz", **dict(flat_params)) mx.savez("model.npz", **dict(flat_params))
Args: Args:
file (file, str): Path to file to which the arrays are saved. file (file, str, pathlib.Path): Path to file to which the arrays are saved.
*args (arrays): Arrays to be saved. *args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved **kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name. with the associated keyword as the output file name.
@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) {
nb::arg(), nb::arg(),
"args"_a, "args"_a,
"kwargs"_a, "kwargs"_a,
nb::sig("def savez_compressed(file: str, *args, **kwargs)"), nb::sig(
"def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
R"pbdoc( R"pbdoc(
Save several arrays to a binary file in compressed ``.npz`` format. Save several arrays to a binary file in compressed ``.npz`` format.
Args: Args:
file (file, str): Path to file to which the arrays are saved. file (file, str, pathlib.Path): Path to file to which the arrays are saved.
*args (arrays): Arrays to be saved. *args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved **kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name. with the associated keyword as the output file name.
@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"), "def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
R"pbdoc( R"pbdoc(
Load array(s) from a binary file. Load array(s) from a binary file.
@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) {
``.gguf``. ``.gguf``.
Args: Args:
file (file, str): File in which the array is saved. file (file, str, pathlib.Path): File in which the array is saved.
format (str, optional): Format of the file. If ``None``, the format (str, optional): Format of the file. If ``None``, the
format is inferred from the file extension. Supported formats: format is inferred from the file extension. Supported formats:
``npy``, ``npz``, and ``safetensors``. Default: ``None``. ``npy``, ``npz``, and ``safetensors``. Default: ``None``.
@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) {
"arrays"_a, "arrays"_a,
"metadata"_a = nb::none(), "metadata"_a = nb::none(),
nb::sig( nb::sig(
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"), "def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
R"pbdoc( R"pbdoc(
Save array(s) to a binary file in ``.safetensors`` format. Save array(s) to a binary file in ``.safetensors`` format.
@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) {
information on the format. information on the format.
Args: Args:
file (file, str): File in which the array is saved. file (file, str, pathlib.Path): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to arrays (dict(str, array)): The dictionary of names to arrays to
be saved. be saved.
metadata (dict(str, str), optional): The dictionary of metadata (dict(str, str), optional): The dictionary of
@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) {
"arrays"_a, "arrays"_a,
"metadata"_a = nb::none(), "metadata"_a = nb::none(),
nb::sig( nb::sig(
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"), "def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
R"pbdoc( R"pbdoc(
Save array(s) to a binary file in ``.gguf`` format. Save array(s) to a binary file in ``.gguf`` format.
@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) {
more information on the format. more information on the format.
Args: Args:
file (file, str): File in which the array is saved. file (file, str, pathlib.Path): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to arrays (dict(str, array)): The dictionary of names to arrays to
be saved. be saved.
metadata (dict(str, Union[array, str, list(str)])): The dictionary metadata (dict(str, Union[array, str, list(str)])): The dictionary