GGUF support (#350)

* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Juarez Bochi
2024-01-10 16:22:48 -05:00
committed by GitHub
parent e3e933c6bc
commit b7f905787e
12 changed files with 362 additions and 55 deletions

View File

@@ -90,6 +90,33 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.array_equal(load_dict["test"], save_dict["test"])
)
def test_save_and_load_gguf(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
# TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
for dt in supported_dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(
self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
}
mx.save_gguf(save_file_mlx, save_dict)
load_dict = mx.load(save_file_mlx)
self.assertTrue("test" in load_dict)
self.assertTrue(
mx.array_equal(load_dict["test"], save_dict["test"])
)
def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
@@ -194,13 +221,24 @@ class TestLoad(mlx_tests.MLXTestCase):
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
# safetensors only works with row contiguous
save_file = os.path.join(self.test_dir, "a.gguf")
mx.save_gguf(save_file, {"a": a})
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
# safetensors and gguf only work with row contiguous
# make sure col contiguous is handled properly
save_file = os.path.join(self.test_dir, "a.safetensors")
a = mx.arange(4).reshape(2, 2).T
mx.save_safetensors(save_file, {"a": a})
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
save_file = os.path.join(self.test_dir, "a.gguf")
mx.save_gguf(save_file, {"a": a})
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
if __name__ == "__main__":
unittest.main()