From b34bf5d52b8f71c2e8e53dd82831b8333259294c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 6 Jan 2024 12:44:02 -0800 Subject: [PATCH] fix saving for non-contiguous arrays (#389) --- mlx/io/load.cpp | 9 +++++++-- mlx/io/safetensor.cpp | 19 ++++++++++++++----- python/tests/test_load.py | 23 +++++++++++++++++++++++ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 74e0784f8..27c425455 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -50,9 +50,14 @@ void save(std::shared_ptr out_stream, array a, bool retain_graph) { throw std::invalid_argument("[save] cannot serialize an empty array"); } - if (!a.flags().contiguous) { + if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { + a = reshape(flatten(a), a.shape()); + a.eval(retain_graph); + } + // Check once more in-case the above ops change + if (!(a.flags().row_contiguous || a.flags().col_contiguous)) { throw std::invalid_argument( - "[save] cannot serialize a non-contiguous array"); + "[save] can only serialize row or col contiguous arrays"); } //////////////////////////////////////////////////////// diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index a690e6420..bb78be797 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -142,17 +142,26 @@ void save_safetensors( }); size_t offset = 0; for (auto& [key, arr] : a) { - arr.eval(retain_graph_.value_or(arr.is_tracer())); + auto retain = retain_graph_.value_or(arr.is_tracer()); + arr.eval(retain); if (arr.nbytes() == 0) { throw std::invalid_argument( "[save_safetensors] cannot serialize an empty array key: " + key); } - if (!arr.flags().contiguous) { - throw std::invalid_argument( - "[save_safetensors] cannot serialize a non-contiguous array key: " + - key); + // Try to make it row contiguous + if (!arr.flags().row_contiguous) { + arr = reshape(flatten(arr), arr.shape()); + arr.eval(retain); } + + // Has to be row-major now but, check one more time in case + // any of the above change in the future + if (!arr.flags().row_contiguous) { + throw std::invalid_argument( + "[save_safetensors] can only serialize row-major arrays"); + } + json child; child["dtype"] = dtype_to_safetensor_str(arr.dtype()); child["shape"] = arr.shape(); diff --git a/python/tests/test_load.py b/python/tests/test_load.py index d1638422b..66cf4aa4e 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -178,6 +178,29 @@ class TestLoad(mlx_tests.MLXTestCase): for k, v in load_arr_mlx_npy.items(): self.assertTrue(np.array_equal(save_arrs_npy[k], v)) + def test_non_contiguous(self): + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + a = mx.broadcast_to(mx.array([1, 2]), [4, 2]) + + save_file = os.path.join(self.test_dir, "a.npy") + mx.save(save_file, a) + aload = mx.load(save_file) + self.assertTrue(mx.array_equal(a, aload)) + + save_file = os.path.join(self.test_dir, "a.safetensors") + mx.save_safetensors(save_file, {"a": a}) + aload = mx.load(save_file)["a"] + self.assertTrue(mx.array_equal(a, aload)) + + # safetensors only works with row contiguous + # make sure col contiguous is handled properly + a = mx.arange(4).reshape(2, 2).T + mx.save_safetensors(save_file, {"a": a}) + aload = mx.load(save_file)["a"] + self.assertTrue(mx.array_equal(a, aload)) + if __name__ == "__main__": unittest.main()