mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Removes the retain_graph flag (#385)
				
					
				
			* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							449b43762e
						
					
				
				
					commit
					a611b0bc82
				
			| @@ -40,11 +40,11 @@ inline bool is_big_endian_() { | ||||
| } // namespace | ||||
|  | ||||
| /** Save array to out stream in .npy format */ | ||||
| void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) { | ||||
| void save(std::shared_ptr<io::Writer> out_stream, array a) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check array | ||||
|  | ||||
|   a.eval(retain_graph); | ||||
|   a.eval(); | ||||
|  | ||||
|   if (a.nbytes() == 0) { | ||||
|     throw std::invalid_argument("[save] cannot serialize an empty array"); | ||||
| @@ -52,7 +52,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) { | ||||
|  | ||||
|   if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { | ||||
|     a = reshape(flatten(a), a.shape()); | ||||
|     a.eval(retain_graph); | ||||
|     a.eval(); | ||||
|   } | ||||
|   // Check once more in-case the above ops change | ||||
|   if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { | ||||
| @@ -127,7 +127,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) { | ||||
| } | ||||
|  | ||||
| /** Save array to file in .npy format */ | ||||
| void save(const std::string& file_, array a, bool retain_graph) { | ||||
| void save(const std::string& file_, array a) { | ||||
|   // Open and check file | ||||
|   std::string file = file_; | ||||
|  | ||||
| @@ -136,7 +136,7 @@ void save(const std::string& file_, array a, bool retain_graph) { | ||||
|     file += ".npy"; | ||||
|  | ||||
|   // Serialize array | ||||
|   save(std::make_shared<io::FileWriter>(file), a, retain_graph); | ||||
|   save(std::make_shared<io::FileWriter>(file), a); | ||||
| } | ||||
|  | ||||
| /** Load array from reader in .npy format */ | ||||
|   | ||||
| @@ -111,4 +111,4 @@ class FileWriter : public Writer { | ||||
| }; | ||||
|  | ||||
| } // namespace io | ||||
| } // namespace mlx::core | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -125,8 +125,7 @@ std::unordered_map<std::string, array> load_safetensors( | ||||
| /** Save array to out stream in .npy format */ | ||||
| void save_safetensors( | ||||
|     std::shared_ptr<io::Writer> out_stream, | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::optional<bool> retain_graph_) { | ||||
|     std::unordered_map<std::string, array> a) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check file | ||||
|   if (!out_stream->good() || !out_stream->is_open()) { | ||||
| @@ -142,8 +141,7 @@ void save_safetensors( | ||||
|   }); | ||||
|   size_t offset = 0; | ||||
|   for (auto& [key, arr] : a) { | ||||
|     auto retain = retain_graph_.value_or(arr.is_tracer()); | ||||
|     arr.eval(retain); | ||||
|     arr.eval(); | ||||
|     if (arr.nbytes() == 0) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] cannot serialize an empty array key: " + key); | ||||
| @@ -152,7 +150,7 @@ void save_safetensors( | ||||
|     // Try to make it row contiguous | ||||
|     if (!arr.flags().row_contiguous) { | ||||
|       arr = reshape(flatten(arr), arr.shape()); | ||||
|       arr.eval(retain); | ||||
|       arr.eval(); | ||||
|     } | ||||
|  | ||||
|     // Has to be row-major now but, check one more time in case | ||||
| @@ -181,8 +179,7 @@ void save_safetensors( | ||||
|  | ||||
| void save_safetensors( | ||||
|     const std::string& file_, | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::optional<bool> retain_graph) { | ||||
|     std::unordered_map<std::string, array> a) { | ||||
|   // Open and check file | ||||
|   std::string file = file_; | ||||
|  | ||||
| @@ -192,7 +189,7 @@ void save_safetensors( | ||||
|     file += ".safetensors"; | ||||
|  | ||||
|   // Serialize array | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph); | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
		Reference in New Issue
	
	Block a user