remove pathlib refs

This commit is contained in:
dc-dc-dc 2023-12-22 19:06:31 -05:00
parent ee6ce00aee
commit c6d7702ef0
2 changed files with 6 additions and 8 deletions

View File

@ -178,7 +178,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
} }
throw std::invalid_argument( throw std::invalid_argument(
"[load_safetensor] Input must be a file-like object, string, or pathlib.Path"); "[load_safetensor] Input must be a file-like object, or string");
} }
std::unordered_map<std::string, array> mlx_load_npz_helper( std::unordered_map<std::string, array> mlx_load_npz_helper(
@ -233,7 +233,7 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
return {arr}; return {arr};
} }
throw std::invalid_argument( throw std::invalid_argument(
"[load_npy] Input must be a file-like object, string, or pathlib.Path"); "[load_npy] Input must be a file-like object, or string");
} }
DictOrArray mlx_load_helper( DictOrArray mlx_load_helper(
@ -244,11 +244,9 @@ DictOrArray mlx_load_helper(
std::string fname; std::string fname;
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
fname = py::cast<std::string>(file); fname = py::cast<std::string>(file);
} else if (is_istream_object(file)) {
fname = file.attr("name").cast<std::string>();
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[load] Input must be a file-like object, string, or pathlib.Path"); "[load] Input must be a file-like object, or string");
} }
size_t ext = fname.find_last_of('.'); size_t ext = fname.find_last_of('.');
if (ext == std::string::npos) { if (ext == std::string::npos) {
@ -364,7 +362,7 @@ void mlx_save_helper(
} }
throw std::invalid_argument( throw std::invalid_argument(
"[save] Input must be a file-like object, string, or pathlib.Path"); "[save] Input must be a file-like object, or string");
} }
void mlx_savez_helper( void mlx_savez_helper(
@ -440,5 +438,5 @@ void mlx_save_safetensor_helper(
} }
throw std::invalid_argument( throw std::invalid_argument(
"[save_safetensor] Input must be a file-like object, string, or pathlib.Path"); "[save_safetensor] Input must be a file-like object, or string");
} }

View File

@ -20,7 +20,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
void mlx_save_safetensor_helper( void mlx_save_safetensor_helper(
py::object file, py::object file,
py::dict d, py::dict d,
bool retain_graph = true); std::optional<bool> retain_graph = std::nullopt);
DictOrArray mlx_load_helper( DictOrArray mlx_load_helper(
py::object file, py::object file,