add back retain_graph argument

This commit is contained in:
dc-dc-dc 2023-12-20 14:39:31 -05:00
parent edd55388ce
commit fdf9d99f0f
5 changed files with 24 additions and 11 deletions

View File

@ -1067,8 +1067,10 @@ std::unordered_map<std::string, array> load_safetensor(
void save_safetensor( void save_safetensor(
std::shared_ptr<io::Writer> in_stream, std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>); std::unordered_map<std::string, array>,
bool retain_graph = true);
void save_safetensor( void save_safetensor(
const std::string& file, const std::string& file,
std::unordered_map<std::string, array>); std::unordered_map<std::string, array>,
bool retain_graph = true);
} // namespace mlx::core } // namespace mlx::core

View File

@ -126,7 +126,8 @@ std::unordered_map<std::string, array> load_safetensor(
/** Save array to out stream in .npy format */ /** Save array to out stream in .npy format */
void save_safetensor( void save_safetensor(
std::shared_ptr<io::Writer> out_stream, std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a) { std::unordered_map<std::string, array> a,
bool retain_graph) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check file // Check file
if (!out_stream->good() || !out_stream->is_open()) { if (!out_stream->good() || !out_stream->is_open()) {
@ -142,7 +143,7 @@ void save_safetensor(
}); });
size_t offset = 0; size_t offset = 0;
for (auto& [key, arr] : a) { for (auto& [key, arr] : a) {
arr.eval(false); arr.eval(retain_graph);
if (arr.nbytes() == 0) { if (arr.nbytes() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"[save_safetensor] cannot serialize an empty array key: " + key); "[save_safetensor] cannot serialize an empty array key: " + key);
@ -172,7 +173,8 @@ void save_safetensor(
void save_safetensor( void save_safetensor(
const std::string& file_, const std::string& file_,
std::unordered_map<std::string, array> a) { std::unordered_map<std::string, array> a,
bool retain_graph) {
// Open and check file // Open and check file
std::string file = file_; std::string file = file_;
@ -182,7 +184,7 @@ void save_safetensor(
file += ".safetensors"; file += ".safetensors";
// Serialize array // Serialize array
save_safetensor(std::make_shared<io::FileWriter>(file), a); save_safetensor(std::make_shared<io::FileWriter>(file), a, retain_graph);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -419,16 +419,19 @@ void mlx_savez_helper(
return; return;
} }
void mlx_save_safetensor_helper(py::object file, py::dict d) { void mlx_save_safetensor_helper(
py::object file,
py::dict d,
bool retain_graph) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) { if (py::isinstance<py::str>(file)) {
save_safetensor(py::cast<std::string>(file), arrays_map); save_safetensor(py::cast<std::string>(file), arrays_map, retain_graph);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release gil; py::gil_scoped_release gil;
save_safetensor(writer, arrays_map); save_safetensor(writer, arrays_map, retain_graph);
} }
return; return;

View File

@ -17,7 +17,10 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
std::unordered_map<std::string, array> mlx_load_safetensor_helper( std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file, py::object file,
StreamOrDevice s); StreamOrDevice s);
void mlx_save_safetensor_helper(py::object file, py::dict d); void mlx_save_safetensor_helper(
py::object file,
py::dict d,
bool retain_graph = true);
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
void mlx_save_helper( void mlx_save_helper(

View File

@ -2953,15 +2953,18 @@ void init_ops(py::module_& m) {
"file"_a, "file"_a,
"d"_a, "d"_a,
py::pos_only(), py::pos_only(),
"retain_graph"_a = true,
py::kw_only(), py::kw_only(),
R"pbdoc( R"pbdoc(
save_safetensor(file: str, d: Dict[str, array], /, *) save_safetensor(file: str, d: Dict[str, array], /, retain_graph: bool = True, *)
Save array(s) to a binary file in ``.safetensors`` format. Save array(s) to a binary file in ``.safetensors`` format.
Args: Args:
file (file, str): File in which the array is saved file (file, str): File in which the array is saved
d (Dict[str, array]): The dict mapping name to array to be saved d (Dict[str, array]): The dict mapping name to array to be saved
retain_graph(bool): Optional argument to retain graph
during array evaluation before saving. Default: True
)pbdoc"); )pbdoc");
m.def( m.def(
"where", "where",