mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
Merge c2511fd83a
into 3dcb286baf
This commit is contained in:
commit
2e37f68a99
@ -3,6 +3,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
save_file = os.path.join(self.test_dir, f"mlx_path.npy")
|
||||
save_arr = mx.ones((32,))
|
||||
mx.save(Path(save_file), save_arr)
|
||||
|
||||
# Load array saved by mlx as mlx array
|
||||
load_arr = mx.load(Path(save_file))
|
||||
self.assertTrue(mx.array_equal(load_arr, save_arr))
|
||||
|
||||
def test_save_and_load_safetensors(self):
|
||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||
with self.assertRaises(Exception):
|
||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
|
||||
mx.save_safetensors(
|
||||
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
||||
)
|
||||
res = mx.load(test_file, return_metadata=True)
|
||||
self.assertEqual(len(res), 2)
|
||||
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||
for obj in [str, Path]:
|
||||
mx.save_safetensors(
|
||||
obj(test_file),
|
||||
{"test": mx.ones((2, 2))},
|
||||
{"testing": "test", "format": "mlx"},
|
||||
)
|
||||
res = mx.load(obj(test_file), return_metadata=True)
|
||||
self.assertEqual(len(res), 2)
|
||||
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||
|
||||
for dt in self.dtypes + ["bfloat16"]:
|
||||
with self.subTest(dtype=dt):
|
||||
@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||
)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_path_test_fs.gguf")
|
||||
save_dict = {"test": mx.ones(shape)}
|
||||
mx.save_gguf(Path(save_file_mlx), save_dict)
|
||||
load_dict = mx.load(Path(save_file_mlx))
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
|
||||
def test_load_f8_e4m3(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
Loading…
Reference in New Issue
Block a user