# Copyright © 2023 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestLoad(mlx_tests.MLXTestCase):
    """Round-trip tests for saving and loading MLX arrays.

    The tests write arrays into a class-wide temporary directory using the
    ``.npy``, ``.npz``, ``.safetensors`` and ``.gguf`` formats and verify that
    what is loaded back equals what was saved. For the NumPy formats the
    files are also exchanged with NumPy in both directions.
    """

    # dtype names resolved with ``getattr`` on ``np`` and/or ``mx``.
    dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float16",
        "complex64",
    ]

    # Shapes shared by the single-array round-trip tests.
    array_shapes = [(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory shared by every test in the class."""
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory and everything written into it."""
        cls.test_dir_fid.cleanup()

    @staticmethod
    def _make_test_array(shape, dt):
        """Return an MLX array of ``shape`` with the MLX dtype named ``dt``.

        Floating point dtypes get normally distributed random values; every
        other dtype gets an array of ones.
        """
        dtype = getattr(mx, dt)
        if dt in ["float32", "float16", "bfloat16"]:
            return mx.random.normal(shape=shape, dtype=dtype)
        return mx.ones(shape, dtype=dtype)

    def test_save_and_load(self):
        """Round-trip single arrays through ``.npy`` files given by path."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.array_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
                        save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        mx.save(save_file_mlx, save_arr_mlx)
                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        load_arr_mlx_mlx = mx.load(save_file_mlx)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        load_arr_npy_mlx = mx.load(save_file_npy)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_save_and_load_safetensors(self):
        """Round-trip a one-entry dict through safetensors via file objects."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes + ["bfloat16"]:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.array_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
                        )
                        save_dict = {"test": self._make_test_array(shape, dt)}

                        with open(save_file_mlx, "wb") as f:
                            mx.save_safetensors(f, save_dict)
                        with open(save_file_mlx, "rb") as f:
                            load_dict = mx.load(f)

                        self.assertIn("test", load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_gguf(self):
        """Round-trip a one-entry dict through a GGUF file given by path."""
        os.makedirs(self.test_dir, exist_ok=True)

        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
        for dt in supported_dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.array_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
                        )
                        save_dict = {"test": self._make_test_array(shape, dt)}

                        mx.save_gguf(save_file_mlx, save_dict)
                        load_dict = mx.load(save_file_mlx)

                        self.assertIn("test", load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_fs(self):
        """Round-trip single arrays through ``.npy`` using open file objects."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.array_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.npy"
                        )
                        save_file_npy = os.path.join(
                            self.test_dir, f"npy_{dt}_{i}_fs.npy"
                        )

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        with open(save_file_mlx, "wb") as f:
                            mx.save(f, save_arr_mlx)

                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        with open(save_file_mlx, "rb") as f:
                            load_arr_mlx_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        with open(save_file_npy, "rb") as f:
                            load_arr_npy_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_savez_and_loadz(self):
        """Round-trip several named arrays through plain and compressed npz."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
                save_file_mlx_uncomp = os.path.join(
                    self.test_dir, f"mlx_{dt}_uncomp.npz"
                )
                save_file_npy_uncomp = os.path.join(
                    self.test_dir, f"npy_{dt}_uncomp.npz"
                )
                save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
                save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")

                # Make dictionary of multiple arrays, one per shape
                save_arrs_npy = {
                    f"save_arr_{i}": np.random.uniform(0.0, 32.0, size=shape).astype(
                        getattr(np, dt)
                    )
                    for i, shape in enumerate(shapes)
                }
                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
                expected_keys = set(save_arrs_npy)

                # Save as npz files
                np.savez(save_file_npy_uncomp, **save_arrs_npy)
                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
                np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)

                for save_file_npy, save_file_mlx in (
                    (save_file_npy_uncomp, save_file_mlx_uncomp),
                    (save_file_npy_comp, save_file_mlx_comp),
                ):
                    # Load arrays saved by mlx as mlx arrays. The key check
                    # keeps the loop below from passing on a partial load.
                    load_arr_mlx_mlx = mx.load(save_file_mlx)
                    self.assertEqual(set(load_arr_mlx_mlx.keys()), expected_keys)
                    for k, v in load_arr_mlx_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by numpy as mlx arrays
                    load_arr_npy_mlx = mx.load(save_file_npy)
                    self.assertEqual(set(load_arr_npy_mlx.keys()), expected_keys)
                    for k, v in load_arr_npy_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by mlx as numpy arrays. The archive
                    # object holds an open file handle, so close it explicitly.
                    with np.load(save_file_mlx) as load_arr_mlx_npy:
                        self.assertEqual(set(load_arr_mlx_npy.files), expected_keys)
                        for k, v in load_arr_mlx_npy.items():
                            self.assertTrue(np.array_equal(save_arrs_npy[k], v))

    def test_non_contiguous(self):
        """Save and reload a broadcast array and a transposed array."""
        os.makedirs(self.test_dir, exist_ok=True)

        a = mx.broadcast_to(mx.array([1, 2]), [4, 2])

        save_file = os.path.join(self.test_dir, "a.npy")
        mx.save(save_file, a)
        aload = mx.load(save_file)
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.safetensors")
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        # safetensors and gguf only work with row contiguous
        # make sure col contiguous is handled properly
        save_file = os.path.join(self.test_dir, "a.safetensors")
        a = mx.arange(4).reshape(2, 2).T
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()