mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Removes the retain_graph flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
committed by
GitHub
parent
449b43762e
commit
a611b0bc82
@@ -40,11 +40,11 @@ inline bool is_big_endian_() {
|
||||
} // namespace
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array
|
||||
|
||||
a.eval(retain_graph);
|
||||
a.eval();
|
||||
|
||||
if (a.nbytes() == 0) {
|
||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||
@@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
a = reshape(flatten(a), a.shape());
|
||||
a.eval(retain_graph);
|
||||
a.eval();
|
||||
}
|
||||
// Check once more in-case the above ops change
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
@@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
}
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file_, array a, bool retain_graph) {
|
||||
void save(const std::string& file_, array a) {
|
||||
// Open and check file
|
||||
std::string file = file_;
|
||||
|
||||
@@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) {
|
||||
file += ".npy";
|
||||
|
||||
// Serialize array
|
||||
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||
save(std::make_shared<io::FileWriter>(file), a);
|
||||
}
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
|
||||
Reference in New Issue
Block a user