mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 06:14:43 +08:00
GGUF: Load and save metadata (#446)
* gguf metadata --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -117,6 +117,115 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||
)
|
||||
|
||||
def test_save_and_load_gguf_metadata_basic(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
metadata = {}
|
||||
|
||||
# Empty works
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
|
||||
# Loads without the metadata
|
||||
load_dict = mx.load(save_file_mlx)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
|
||||
# Loads empty metadata
|
||||
load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
self.assertEqual(len(meta_load_dict), 0)
|
||||
|
||||
# Loads string metadata
|
||||
metadata = {"meta": "data"}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertTrue("meta" in meta_load_dict)
|
||||
self.assertEqual(meta_load_dict["meta"], "data")
|
||||
|
||||
def test_save_and_load_gguf_metadata_arrays(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
|
||||
# Test scalars and one dimensional arrays
|
||||
for t in [
|
||||
mx.uint8,
|
||||
mx.int8,
|
||||
mx.uint16,
|
||||
mx.int16,
|
||||
mx.uint32,
|
||||
mx.int32,
|
||||
mx.uint64,
|
||||
mx.int64,
|
||||
mx.float32,
|
||||
]:
|
||||
for shape in [(), (2,)]:
|
||||
arr = mx.random.uniform(shape=shape).astype(t)
|
||||
metadata = {"meta": arr}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertTrue("meta" in meta_load_dict)
|
||||
self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr))
|
||||
self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype)
|
||||
|
||||
for t in [mx.float16, mx.bfloat16, mx.complex64]:
|
||||
with self.assertRaises(ValueError):
|
||||
arr = mx.array(1, t)
|
||||
metadata = {"meta": arr}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
|
||||
def test_save_and_load_gguf_metadata_mixed(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
|
||||
# Test string and array
|
||||
arr = mx.array(1.5)
|
||||
metadata = {"meta1": arr, "meta2": "data"}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 2)
|
||||
self.assertTrue("meta1" in meta_load_dict)
|
||||
self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr))
|
||||
self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype)
|
||||
self.assertTrue("meta2" in meta_load_dict)
|
||||
self.assertEqual(meta_load_dict["meta2"], "data")
|
||||
|
||||
# Test list of strings
|
||||
metadata = {"meta": ["data1", "data2", "data345"]}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertEqual(meta_load_dict["meta"], metadata["meta"])
|
||||
|
||||
# Test a combination of stuff
|
||||
metadata = {
|
||||
"meta1": ["data1", "data2", "data345"],
|
||||
"meta2": mx.array([1, 2, 3, 4]),
|
||||
"meta3": "data",
|
||||
"meta4": mx.array(1.5),
|
||||
}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 4)
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, mx.array):
|
||||
self.assertTrue(mx.array_equal(meta_load_dict[k], v))
|
||||
else:
|
||||
self.assertEqual(meta_load_dict[k], v)
|
||||
|
||||
def test_save_and_load_fs(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
Reference in New Issue
Block a user