Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -40,11 +40,11 @@ inline bool is_big_endian_() {
} // namespace
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
void save(std::shared_ptr<io::Writer> out_stream, array a) {
////////////////////////////////////////////////////////
// Check array
a.eval(retain_graph);
a.eval();
if (a.nbytes() == 0) {
throw std::invalid_argument("[save] cannot serialize an empty array");
@@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
a = reshape(flatten(a), a.shape());
a.eval(retain_graph);
a.eval();
}
// Check once more in-case the above ops change
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
@@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
}
/** Save array to file in .npy format */
void save(const std::string& file_, array a, bool retain_graph) {
void save(const std::string& file_, array a) {
// Open and check file
std::string file = file_;
@@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
file += ".npy";
// Serialize array
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
save(std::make_shared<io::FileWriter>(file), a);
}
/** Load array from reader in .npy format */

View File

@@ -111,4 +111,4 @@ class FileWriter : public Writer {
};
} // namespace io
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors(
/** Save array to out stream in .npy format */
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph_) {
std::unordered_map<std::string, array> a) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
@@ -142,8 +141,7 @@ void save_safetensors(
});
size_t offset = 0;
for (auto& [key, arr] : a) {
auto retain = retain_graph_.value_or(arr.is_tracer());
arr.eval(retain);
arr.eval();
if (arr.nbytes() == 0) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key);
@@ -152,7 +150,7 @@ void save_safetensors(
// Try to make it row contiguous
if (!arr.flags().row_contiguous) {
arr = reshape(flatten(arr), arr.shape());
arr.eval(retain);
arr.eval();
}
// Has to be row-major now but, check one more time in case
@@ -181,8 +179,7 @@ void save_safetensors(
void save_safetensors(
const std::string& file_,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph) {
std::unordered_map<std::string, array> a) {
// Open and check file
std::string file = file_;
@@ -192,7 +189,7 @@ void save_safetensors(
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
save_safetensors(std::make_shared<io::FileWriter>(file), a);
}
} // namespace mlx::core