# Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-04 10:38:10 +08:00) - 369 lines, 15 KiB, Python
# Copyright © 2023 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestLoad(mlx_tests.MLXTestCase):
    """Round-trip tests for the MLX array save/load functions.

    Exercises ``mx.save``/``mx.load`` with ``.npy`` files, ``mx.savez`` and
    ``mx.savez_compressed`` with ``.npz`` files, ``mx.save_safetensors`` and
    ``mx.save_gguf`` (with and without metadata). Every file is written into
    a temporary directory created once for the class and removed afterwards.
    """

    # dtype names resolved with getattr() on both ``np`` and ``mx``.
    dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float16",
        "complex64",
    ]

    # Shapes used by the per-dtype round-trip tests.
    roundtrip_shapes = [(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]

    # dtypes whose test arrays are drawn with mx.random.normal; every other
    # dtype is filled with ones (see _make_array).
    float_dtypes = ("float32", "float16", "bfloat16")

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory shared by all tests in the class."""
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory and everything written into it."""
        cls.test_dir_fid.cleanup()

    def _make_array(self, shape, dt):
        """Return an MLX array of ``shape`` and dtype name ``dt`` to save.

        Float dtypes listed in ``float_dtypes`` get random normal values;
        all other dtypes get an array of ones.
        """
        dtype = getattr(mx, dt)
        if dt in self.float_dtypes:
            return mx.random.normal(shape=shape, dtype=dtype)
        return mx.ones(shape, dtype=dtype)

    def test_save_and_load(self):
        """Check .npy round trips between MLX and NumPy using file paths."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.roundtrip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
                        save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        mx.save(save_file_mlx, save_arr_mlx)
                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        load_arr_mlx_mlx = mx.load(save_file_mlx)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        load_arr_npy_mlx = mx.load(save_file_npy)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_save_and_load_safetensors(self):
        """Check safetensors metadata handling and per-dtype round trips."""
        os.makedirs(self.test_dir, exist_ok=True)

        test_file = os.path.join(self.test_dir, "test.safetensors")

        # A non-string metadata value (0) is expected to be rejected.
        with self.assertRaises(Exception):
            mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
        res = mx.load(test_file, return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})

        for dt in self.dtypes + ["bfloat16"]:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.roundtrip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
                        )
                        save_dict = {"test": self._make_array(shape, dt)}

                        # Save and load through open file objects, not paths.
                        with open(save_file_mlx, "wb") as f:
                            mx.save_safetensors(f, save_dict)
                        with open(save_file_mlx, "rb") as f:
                            load_dict = mx.load(f)

                        self.assertIn("test", load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_gguf(self):
        """Check GGUF round trips for the currently supported dtypes."""
        os.makedirs(self.test_dir, exist_ok=True)

        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
        for dt in supported_dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.roundtrip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
                        )
                        save_dict = {"test": self._make_array(shape, dt)}

                        mx.save_gguf(save_file_mlx, save_dict)
                        load_dict = mx.load(save_file_mlx)

                        self.assertIn("test", load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_gguf_metadata_basic(self):
        """Check GGUF files saved with empty and with string metadata."""
        os.makedirs(self.test_dir, exist_ok=True)

        save_file_mlx = os.path.join(self.test_dir, "mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
        metadata = {}

        # Empty works
        mx.save_gguf(save_file_mlx, save_dict, metadata)

        # Loads without the metadata
        load_dict = mx.load(save_file_mlx)
        self.assertIn("test", load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))

        # Loads empty metadata
        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertIn("test", load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
        self.assertEqual(len(meta_load_dict), 0)

        # Loads string metadata
        metadata = {"meta": "data"}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertIn("test", load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
        self.assertEqual(len(meta_load_dict), 1)
        self.assertIn("meta", meta_load_dict)
        self.assertEqual(meta_load_dict["meta"], "data")

    def test_save_and_load_gguf_metadata_arrays(self):
        """Check array-valued GGUF metadata, including rejected dtypes."""
        os.makedirs(self.test_dir, exist_ok=True)

        save_file_mlx = os.path.join(self.test_dir, "mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}

        # Test scalars and one dimensional arrays
        for t in [
            mx.uint8,
            mx.int8,
            mx.uint16,
            mx.int16,
            mx.uint32,
            mx.int32,
            mx.uint64,
            mx.int64,
            mx.float32,
        ]:
            for shape in [(), (2,)]:
                with self.subTest(dtype=t, shape=shape):
                    # NOTE(review): if uniform() samples from [0, 1) by
                    # default, the integer casts here are all zeros, which
                    # makes the value comparison weak - confirm.
                    arr = mx.random.uniform(shape=shape).astype(t)
                    metadata = {"meta": arr}
                    mx.save_gguf(save_file_mlx, save_dict, metadata)
                    _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
                    self.assertEqual(len(meta_load_dict), 1)
                    self.assertIn("meta", meta_load_dict)
                    self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr))
                    self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype)

        for t in [mx.float16, mx.bfloat16, mx.complex64]:
            with self.subTest(dtype=t):
                # Build the inputs outside assertRaises so that only a
                # ValueError raised by save_gguf itself satisfies the test.
                arr = mx.array(1, t)
                metadata = {"meta": arr}
                with self.assertRaises(ValueError):
                    mx.save_gguf(save_file_mlx, save_dict, metadata)

    def test_save_and_load_gguf_metadata_mixed(self):
        """Check GGUF metadata mixing strings, string lists and arrays."""
        os.makedirs(self.test_dir, exist_ok=True)

        save_file_mlx = os.path.join(self.test_dir, "mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}

        # Test string and array
        arr = mx.array(1.5)
        metadata = {"meta1": arr, "meta2": "data"}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 2)
        self.assertIn("meta1", meta_load_dict)
        self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr))
        self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype)
        self.assertIn("meta2", meta_load_dict)
        self.assertEqual(meta_load_dict["meta2"], "data")

        # Test list of strings
        metadata = {"meta": ["data1", "data2", "data345"]}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 1)
        self.assertEqual(meta_load_dict["meta"], metadata["meta"])

        # Test a combination of stuff
        metadata = {
            "meta1": ["data1", "data2", "data345"],
            "meta2": mx.array([1, 2, 3, 4]),
            "meta3": "data",
            "meta4": mx.array(1.5),
        }
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 4)
        for k, v in metadata.items():
            if isinstance(v, mx.array):
                self.assertTrue(mx.array_equal(meta_load_dict[k], v))
            else:
                self.assertEqual(meta_load_dict[k], v)

    def test_save_and_load_fs(self):
        """Check .npy round trips when MLX is given open file objects."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.roundtrip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.npy"
                        )
                        save_file_npy = os.path.join(
                            self.test_dir, f"npy_{dt}_{i}_fs.npy"
                        )

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        with open(save_file_mlx, "wb") as f:
                            mx.save(f, save_arr_mlx)

                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        with open(save_file_mlx, "rb") as f:
                            load_arr_mlx_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        with open(save_file_npy, "rb") as f:
                            load_arr_npy_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_savez_and_loadz(self):
        """Check .npz round trips, compressed and uncompressed."""
        os.makedirs(self.test_dir, exist_ok=True)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
                save_file_mlx_uncomp = os.path.join(
                    self.test_dir, f"mlx_{dt}_uncomp.npz"
                )
                save_file_npy_uncomp = os.path.join(
                    self.test_dir, f"npy_{dt}_uncomp.npz"
                )
                save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
                save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")

                # Make dictionary of multiple
                save_arrs_npy = {
                    f"save_arr_{i}": np.random.uniform(0.0, 32.0, size=shape).astype(
                        getattr(np, dt)
                    )
                    for i, shape in enumerate(shapes)
                }
                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
                expected_keys = set(save_arrs_mlx)

                # Save as npz files
                np.savez(save_file_npy_uncomp, **save_arrs_npy)
                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
                np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)

                for save_file_npy, save_file_mlx in (
                    (save_file_npy_uncomp, save_file_mlx_uncomp),
                    (save_file_npy_comp, save_file_mlx_comp),
                ):
                    # Each comparison first checks the key set: looping over
                    # the loaded items alone would pass on an empty archive.

                    # Load array saved by mlx as mlx array
                    load_arr_mlx_mlx = mx.load(save_file_mlx)
                    self.assertEqual(set(load_arr_mlx_mlx), expected_keys)
                    for k, v in load_arr_mlx_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by numpy as mlx arrays
                    load_arr_npy_mlx = mx.load(save_file_npy)
                    self.assertEqual(set(load_arr_npy_mlx), expected_keys)
                    for k, v in load_arr_npy_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load array saved by mlx as numpy array. np.load keeps
                    # the .npz archive open, so close it deterministically.
                    with np.load(save_file_mlx) as load_arr_mlx_npy:
                        self.assertEqual(set(load_arr_mlx_npy.files), expected_keys)
                        for k, v in load_arr_mlx_npy.items():
                            self.assertTrue(np.array_equal(save_arrs_npy[k], v))

    def test_non_contiguous(self):
        """Check that broadcast and transposed arrays survive a round trip."""
        os.makedirs(self.test_dir, exist_ok=True)

        a = mx.broadcast_to(mx.array([1, 2]), [4, 2])

        save_file = os.path.join(self.test_dir, "a.npy")
        mx.save(save_file, a)
        aload = mx.load(save_file)
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.safetensors")
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        # safetensors and gguf only work with row contiguous
        # make sure col contiguous is handled properly
        save_file = os.path.join(self.test_dir, "a.safetensors")
        a = mx.arange(4).reshape(2, 2).T
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()