mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
fix saving for non-contiguous arrays (#389)
This commit is contained in:
@@ -178,6 +178,29 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
for k, v in load_arr_mlx_npy.items():
|
||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||
|
||||
def test_non_contiguous(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
||||
|
||||
save_file = os.path.join(self.test_dir, "a.npy")
|
||||
mx.save(save_file, a)
|
||||
aload = mx.load(save_file)
|
||||
self.assertTrue(mx.array_equal(a, aload))
|
||||
|
||||
save_file = os.path.join(self.test_dir, "a.safetensors")
|
||||
mx.save_safetensors(save_file, {"a": a})
|
||||
aload = mx.load(save_file)["a"]
|
||||
self.assertTrue(mx.array_equal(a, aload))
|
||||
|
||||
# safetensors only works with row contiguous
|
||||
# make sure col contiguous is handled properly
|
||||
a = mx.arange(4).reshape(2, 2).T
|
||||
mx.save_safetensors(save_file, {"a": a})
|
||||
aload = mx.load(save_file)["a"]
|
||||
self.assertTrue(mx.array_equal(a, aload))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user