fix saving for non-contiguous arrays (#389)

This commit is contained in:
Awni Hannun
2024-01-06 12:44:02 -08:00
committed by GitHub
parent 608bd43604
commit b34bf5d52b
3 changed files with 44 additions and 7 deletions

View File

@@ -50,9 +50,14 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
throw std::invalid_argument("[save] cannot serialize an empty array");
}
if (!a.flags().contiguous) {
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
a = reshape(flatten(a), a.shape());
a.eval(retain_graph);
}
// Check once more in-case the above ops change
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
throw std::invalid_argument(
"[save] cannot serialize a non-contiguous array");
"[save] can only serialize row or col contiguous arrays");
}
////////////////////////////////////////////////////////