mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix saving for non-contiguous arrays (#389)
This commit is contained in:
		| @@ -50,9 +50,14 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) { | ||||
|     throw std::invalid_argument("[save] cannot serialize an empty array"); | ||||
|   } | ||||
|  | ||||
|   if (!a.flags().contiguous) { | ||||
|   if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { | ||||
|     a = reshape(flatten(a), a.shape()); | ||||
|     a.eval(retain_graph); | ||||
|   } | ||||
|   // Check once more in-case the above ops change | ||||
|   if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { | ||||
|     throw std::invalid_argument( | ||||
|         "[save] cannot serialize a non-contiguous array"); | ||||
|         "[save] can only serialize row or col contiguous arrays"); | ||||
|   } | ||||
|  | ||||
|   //////////////////////////////////////////////////////// | ||||
|   | ||||
| @@ -142,17 +142,26 @@ void save_safetensors( | ||||
|   }); | ||||
|   size_t offset = 0; | ||||
|   for (auto& [key, arr] : a) { | ||||
|     arr.eval(retain_graph_.value_or(arr.is_tracer())); | ||||
|     auto retain = retain_graph_.value_or(arr.is_tracer()); | ||||
|     arr.eval(retain); | ||||
|     if (arr.nbytes() == 0) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] cannot serialize an empty array key: " + key); | ||||
|     } | ||||
|  | ||||
|     if (!arr.flags().contiguous) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] cannot serialize a non-contiguous array key: " + | ||||
|           key); | ||||
|     // Try to make it row contiguous | ||||
|     if (!arr.flags().row_contiguous) { | ||||
|       arr = reshape(flatten(arr), arr.shape()); | ||||
|       arr.eval(retain); | ||||
|     } | ||||
|  | ||||
|     // Has to be row-major now but, check one more time in case | ||||
|     // any of the above change in the future | ||||
|     if (!arr.flags().row_contiguous) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] can only serialize row-major arrays"); | ||||
|     } | ||||
|  | ||||
|     json child; | ||||
|     child["dtype"] = dtype_to_safetensor_str(arr.dtype()); | ||||
|     child["shape"] = arr.shape(); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun