mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 00:01:17 +08:00
fix saving for non-contiguous arrays (#389)
This commit is contained in:
parent
608bd43604
commit
b34bf5d52b
@ -50,9 +50,14 @@ void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
|||||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!a.flags().contiguous) {
|
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||||
|
a = reshape(flatten(a), a.shape());
|
||||||
|
a.eval(retain_graph);
|
||||||
|
}
|
||||||
|
// Check once more in-case the above ops change
|
||||||
|
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save] cannot serialize a non-contiguous array");
|
"[save] can only serialize row or col contiguous arrays");
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
|
@ -142,17 +142,26 @@ void save_safetensors(
|
|||||||
});
|
});
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
arr.eval(retain_graph_.value_or(arr.is_tracer()));
|
auto retain = retain_graph_.value_or(arr.is_tracer());
|
||||||
|
arr.eval(retain);
|
||||||
if (arr.nbytes() == 0) {
|
if (arr.nbytes() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!arr.flags().contiguous) {
|
// Try to make it row contiguous
|
||||||
throw std::invalid_argument(
|
if (!arr.flags().row_contiguous) {
|
||||||
"[save_safetensors] cannot serialize a non-contiguous array key: " +
|
arr = reshape(flatten(arr), arr.shape());
|
||||||
key);
|
arr.eval(retain);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Has to be row-major now but, check one more time in case
|
||||||
|
// any of the above change in the future
|
||||||
|
if (!arr.flags().row_contiguous) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[save_safetensors] can only serialize row-major arrays");
|
||||||
|
}
|
||||||
|
|
||||||
json child;
|
json child;
|
||||||
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||||
child["shape"] = arr.shape();
|
child["shape"] = arr.shape();
|
||||||
|
@ -178,6 +178,29 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
for k, v in load_arr_mlx_npy.items():
|
for k, v in load_arr_mlx_npy.items():
|
||||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||||
|
|
||||||
|
def test_non_contiguous(self):
|
||||||
|
if not os.path.isdir(self.test_dir):
|
||||||
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
|
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
||||||
|
|
||||||
|
save_file = os.path.join(self.test_dir, "a.npy")
|
||||||
|
mx.save(save_file, a)
|
||||||
|
aload = mx.load(save_file)
|
||||||
|
self.assertTrue(mx.array_equal(a, aload))
|
||||||
|
|
||||||
|
save_file = os.path.join(self.test_dir, "a.safetensors")
|
||||||
|
mx.save_safetensors(save_file, {"a": a})
|
||||||
|
aload = mx.load(save_file)["a"]
|
||||||
|
self.assertTrue(mx.array_equal(a, aload))
|
||||||
|
|
||||||
|
# safetensors only works with row contiguous
|
||||||
|
# make sure col contiguous is handled properly
|
||||||
|
a = mx.arange(4).reshape(2, 2).T
|
||||||
|
mx.save_safetensors(save_file, {"a": a})
|
||||||
|
aload = mx.load(save_file)["a"]
|
||||||
|
self.assertTrue(mx.array_equal(a, aload))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user