mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	GGUF: Load and save metadata (#446)
* gguf metadata --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -576,8 +576,8 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|                 ], | ||||
|             ) | ||||
|  | ||||
|             self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item()) | ||||
|             self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item()) | ||||
|  | ||||
|             for r, t in zip(dout_ref, dout_test): | ||||
|                 self.assertListEqual(r.shape, t.shape) | ||||
|                 self.assertTrue(mx.allclose(r, t, atol=1e-5).item()) | ||||
|                 self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) | ||||
|   | ||||
| @@ -117,6 +117,115 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|                             mx.array_equal(load_dict["test"], save_dict["test"]) | ||||
|                         ) | ||||
|  | ||||
|     def test_save_and_load_gguf_metadata_basic(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf") | ||||
|         save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)} | ||||
|         metadata = {} | ||||
|  | ||||
|         # Empty works | ||||
|         mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|  | ||||
|         # Loads without the metadata | ||||
|         load_dict = mx.load(save_file_mlx) | ||||
|         self.assertTrue("test" in load_dict) | ||||
|         self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) | ||||
|  | ||||
|         # Loads empty metadata | ||||
|         load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) | ||||
|         self.assertTrue("test" in load_dict) | ||||
|         self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) | ||||
|         self.assertEqual(len(meta_load_dict), 0) | ||||
|  | ||||
|         # Loads string metadata | ||||
|         metadata = {"meta": "data"} | ||||
|         mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|         load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) | ||||
|         self.assertTrue("test" in load_dict) | ||||
|         self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) | ||||
|         self.assertEqual(len(meta_load_dict), 1) | ||||
|         self.assertTrue("meta" in meta_load_dict) | ||||
|         self.assertEqual(meta_load_dict["meta"], "data") | ||||
|  | ||||
|     def test_save_and_load_gguf_metadata_arrays(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf") | ||||
|         save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)} | ||||
|  | ||||
|         # Test scalars and one dimensional arrays | ||||
|         for t in [ | ||||
|             mx.uint8, | ||||
|             mx.int8, | ||||
|             mx.uint16, | ||||
|             mx.int16, | ||||
|             mx.uint32, | ||||
|             mx.int32, | ||||
|             mx.uint64, | ||||
|             mx.int64, | ||||
|             mx.float32, | ||||
|         ]: | ||||
|             for shape in [(), (2,)]: | ||||
|                 arr = mx.random.uniform(shape=shape).astype(t) | ||||
|                 metadata = {"meta": arr} | ||||
|                 mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|                 _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) | ||||
|                 self.assertEqual(len(meta_load_dict), 1) | ||||
|                 self.assertTrue("meta" in meta_load_dict) | ||||
|                 self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr)) | ||||
|                 self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype) | ||||
|  | ||||
|         for t in [mx.float16, mx.bfloat16, mx.complex64]: | ||||
|             with self.assertRaises(ValueError): | ||||
|                 arr = mx.array(1, t) | ||||
|                 metadata = {"meta": arr} | ||||
|                 mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|  | ||||
|     def test_save_and_load_gguf_metadata_mixed(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf") | ||||
|         save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)} | ||||
|  | ||||
|         # Test string and array | ||||
|         arr = mx.array(1.5) | ||||
|         metadata = {"meta1": arr, "meta2": "data"} | ||||
|         mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|         _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) | ||||
|         self.assertEqual(len(meta_load_dict), 2) | ||||
|         self.assertTrue("meta1" in meta_load_dict) | ||||
|         self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr)) | ||||
|         self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype) | ||||
|         self.assertTrue("meta2" in meta_load_dict) | ||||
|         self.assertEqual(meta_load_dict["meta2"], "data") | ||||
|  | ||||
|         # Test list of strings | ||||
|         metadata = {"meta": ["data1", "data2", "data345"]} | ||||
|         mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|         _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) | ||||
|         self.assertEqual(len(meta_load_dict), 1) | ||||
|         self.assertEqual(meta_load_dict["meta"], metadata["meta"]) | ||||
|  | ||||
|         # Test a combination of stuff | ||||
|         metadata = { | ||||
|             "meta1": ["data1", "data2", "data345"], | ||||
|             "meta2": mx.array([1, 2, 3, 4]), | ||||
|             "meta3": "data", | ||||
|             "meta4": mx.array(1.5), | ||||
|         } | ||||
|         mx.save_gguf(save_file_mlx, save_dict, metadata) | ||||
|         _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) | ||||
|         self.assertEqual(len(meta_load_dict), 4) | ||||
|         for k, v in metadata.items(): | ||||
|             if isinstance(v, mx.array): | ||||
|                 self.assertTrue(mx.array_equal(meta_load_dict[k], v)) | ||||
|             else: | ||||
|                 self.assertEqual(meta_load_dict[k], v) | ||||
|  | ||||
|     def test_save_and_load_fs(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Juarez Bochi
					Juarez Bochi