fix saving for non-contiguous arrays (#389)

This commit is contained in:
Awni Hannun
2024-01-06 12:44:02 -08:00
committed by GitHub
parent 608bd43604
commit b34bf5d52b
3 changed files with 44 additions and 7 deletions

View File

@@ -178,6 +178,29 @@ class TestLoad(mlx_tests.MLXTestCase):
for k, v in load_arr_mlx_npy.items():
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
def test_non_contiguous(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
save_file = os.path.join(self.test_dir, "a.npy")
mx.save(save_file, a)
aload = mx.load(save_file)
self.assertTrue(mx.array_equal(a, aload))
save_file = os.path.join(self.test_dir, "a.safetensors")
mx.save_safetensors(save_file, {"a": a})
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
# safetensors only works with row contiguous
# make sure col contiguous is handled properly
a = mx.arange(4).reshape(2, 2).T
mx.save_safetensors(save_file, {"a": a})
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
if __name__ == "__main__":
unittest.main()