mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
fix saving for non-contiguous arrays (#389)
This commit is contained in:
@@ -50,9 +50,14 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||
}
|
||||
|
||||
if (!a.flags().contiguous) {
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
a = reshape(flatten(a), a.shape());
|
||||
a.eval(retain_graph);
|
||||
}
|
||||
// Check once more in-case the above ops change
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
throw std::invalid_argument(
|
||||
"[save] cannot serialize a non-contiguous array");
|
||||
"[save] can only serialize row or col contiguous arrays");
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
|
@@ -142,17 +142,26 @@ void save_safetensors(
|
||||
});
|
||||
size_t offset = 0;
|
||||
for (auto& [key, arr] : a) {
|
||||
arr.eval(retain_graph_.value_or(arr.is_tracer()));
|
||||
auto retain = retain_graph_.value_or(arr.is_tracer());
|
||||
arr.eval(retain);
|
||||
if (arr.nbytes() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||
}
|
||||
|
||||
if (!arr.flags().contiguous) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] cannot serialize a non-contiguous array key: " +
|
||||
key);
|
||||
// Try to make it row contiguous
|
||||
if (!arr.flags().row_contiguous) {
|
||||
arr = reshape(flatten(arr), arr.shape());
|
||||
arr.eval(retain);
|
||||
}
|
||||
|
||||
// Has to be row-major now but, check one more time in case
|
||||
// any of the above change in the future
|
||||
if (!arr.flags().row_contiguous) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] can only serialize row-major arrays");
|
||||
}
|
||||
|
||||
json child;
|
||||
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||
child["shape"] = arr.shape();
|
||||
|
Reference in New Issue
Block a user