This commit is contained in:
Awni Hannun 2025-08-25 22:57:29 +00:00 committed by GitHub
commit 2e37f68a99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from pathlib import Path
import mlx.core as mx
import mlx_tests
@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
save_file = os.path.join(self.test_dir, f"mlx_path.npy")
save_arr = mx.ones((32,))
mx.save(Path(save_file), save_arr)
# Load array saved by mlx as mlx array
load_arr = mx.load(Path(save_file))
self.assertTrue(mx.array_equal(load_arr, save_arr))
def test_save_and_load_safetensors(self):
test_file = os.path.join(self.test_dir, "test.safetensors")
with self.assertRaises(Exception):
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
)
res = mx.load(test_file, return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
for obj in [str, Path]:
mx.save_safetensors(
obj(test_file),
{"test": mx.ones((2, 2))},
{"testing": "test", "format": "mlx"},
)
res = mx.load(obj(test_file), return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt):
@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.array_equal(load_dict["test"], save_dict["test"])
)
save_file_mlx = os.path.join(self.test_dir, f"mlx_path_test_fs.gguf")
save_dict = {"test": mx.ones(shape)}
mx.save_gguf(Path(save_file_mlx), save_dict)
load_dict = mx.load(Path(save_file_mlx))
self.assertTrue("test" in load_dict)
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
def test_load_f8_e4m3(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)