mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
@@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
def test_save_and_load_safetensors(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes + ["bfloat16"]:
|
||||
with self.subTest(dtype=dt):
|
||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||
with self.subTest(shape=shape):
|
||||
save_file_mlx = os.path.join(
|
||||
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
|
||||
)
|
||||
save_dict = {
|
||||
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
}
|
||||
|
||||
with open(save_file_mlx, "wb") as f:
|
||||
mx.save_safetensors(f, save_dict)
|
||||
with open(save_file_mlx, "rb") as f:
|
||||
load_dict = mx.load(f)
|
||||
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(
|
||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||
)
|
||||
|
||||
def test_save_and_load_fs(self):
|
||||
|
||||
if not os.path.isdir(self.test_dir):
|
||||
|
Reference in New Issue
Block a user