Metadata support for safetensors (#639)

* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
This commit is contained in:
Diogo
2024-02-08 22:33:15 -05:00
committed by GitHub
parent 221f8d3fc2
commit b57bd0488d
8 changed files with 108 additions and 69 deletions

View File

@@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
with self.assertRaises(Exception):
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
)
res = mx.load("test.safetensors", return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt):