mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
add back retain_graph argument
This commit is contained in:
parent
edd55388ce
commit
fdf9d99f0f
@ -1067,8 +1067,10 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
|
||||
void save_safetensor(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
std::unordered_map<std::string, array>);
|
||||
std::unordered_map<std::string, array>,
|
||||
bool retain_graph = true);
|
||||
void save_safetensor(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>);
|
||||
std::unordered_map<std::string, array>,
|
||||
bool retain_graph = true);
|
||||
} // namespace mlx::core
|
||||
|
@ -126,7 +126,8 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
/** Save array to out stream in .npy format */
|
||||
void save_safetensor(
|
||||
std::shared_ptr<io::Writer> out_stream,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
std::unordered_map<std::string, array> a,
|
||||
bool retain_graph) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
@ -142,7 +143,7 @@ void save_safetensor(
|
||||
});
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
arr.eval(false);
|
||||
arr.eval(retain_graph);
|
||||
if (arr.nbytes() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensor] cannot serialize an empty array key: " + key);
|
||||
@ -172,7 +173,8 @@ void save_safetensor(
|
||||
|
||||
void save_safetensor(
|
||||
const std::string& file_,
|
||||
std::unordered_map<std::string, array> a) {
|
||||
std::unordered_map<std::string, array> a,
|
||||
bool retain_graph) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@ -182,7 +184,7 @@ void save_safetensor(
|
||||
file += ".safetensors";
|
||||
|
||||
// Serialize array
|
||||
save_safetensor(std::make_shared<io::FileWriter>(file), a);
|
||||
save_safetensor(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -419,16 +419,19 @@ void mlx_savez_helper(
|
||||
return;
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
bool retain_graph) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_safetensor(py::cast<std::string>(file), arrays_map);
|
||||
save_safetensor(py::cast<std::string>(file), arrays_map, retain_graph);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save_safetensor(writer, arrays_map);
|
||||
save_safetensor(writer, arrays_map, retain_graph);
|
||||
}
|
||||
|
||||
return;
|
||||
|
@ -17,7 +17,10 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
bool retain_graph = true);
|
||||
|
||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
|
||||
void mlx_save_helper(
|
||||
|
@ -2953,15 +2953,18 @@ void init_ops(py::module_& m) {
|
||||
"file"_a,
|
||||
"d"_a,
|
||||
py::pos_only(),
|
||||
"retain_graph"_a = true,
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
save_safetensor(file: str, d: Dict[str, array], /, *)
|
||||
save_safetensor(file: str, d: Dict[str, array], /, retain_graph: bool = True, *)
|
||||
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
Args:
|
||||
file (file, str): File in which the array is saved
|
||||
d (Dict[str, array]): The dict mapping name to array to be saved
|
||||
retain_graph(bool): Optional argument to retain graph
|
||||
during array evaluation before saving. Default: True
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"where",
|
||||
|
Loading…
Reference in New Issue
Block a user