# Copyright © 2023 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestLoad(mlx_tests.MLXTestCase):
    """Round-trip tests for saving and loading arrays via ``mlx.core``.

    Covers the ``.npy``, ``.npz``, ``.safetensors`` and ``.gguf`` paths, both
    through file names and (where tested) through open file objects. All files
    are written into a class-wide temporary directory.
    """

    dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float16",
        "complex64",
    ]

    # Shapes shared by the single-array round-trip tests.
    round_trip_shapes = [(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory shared by every test in the class."""
        super().setUpClass()
        # TemporaryDirectory() creates the directory on construction, so no
        # further existence check is needed here or in the individual tests.
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        """Remove the shared temporary directory and everything in it."""
        cls.test_dir_fid.cleanup()
        super().tearDownClass()

    @staticmethod
    def _sample_array(shape, dt):
        """Return an mlx array of ``shape`` and dtype name ``dt``.

        Floating point dtypes get random normal values; every other dtype
        gets an array of ones.
        """
        dtype = getattr(mx, dt)
        if dt in ["float32", "float16", "bfloat16"]:
            return mx.random.normal(shape=shape, dtype=dtype)
        return mx.ones(shape, dtype=dtype)

    def test_save_and_load(self):
        """Round-trip ``.npy`` files between mlx and numpy using file names."""
        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.round_trip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
                        save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        mx.save(save_file_mlx, save_arr_mlx)
                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        load_arr_mlx_mlx = mx.load(save_file_mlx)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        load_arr_npy_mlx = mx.load(save_file_npy)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_save_and_load_safetensors(self):
        """Round-trip safetensors files, including metadata and file objects."""
        test_file = os.path.join(self.test_dir, "test.safetensors")

        # A non-string metadata value (here an int) is expected to be rejected.
        with self.assertRaises(Exception):
            mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
        res = mx.load(test_file, return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})

        for dt in self.dtypes + ["bfloat16"]:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.round_trip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
                        )
                        save_dict = {"test": self._sample_array(shape, dt)}

                        # Save and load through open file objects rather
                        # than through the file name.
                        with open(save_file_mlx, "wb") as f:
                            mx.save_safetensors(f, save_dict)
                        with open(save_file_mlx, "rb") as f:
                            load_dict = mx.load(f)

                        self.assertIn("test", load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_gguf(self):
        """Round-trip gguf files for the currently supported dtypes."""
        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
        for dt in supported_dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.round_trip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
                        )
                        save_dict = {"test": self._sample_array(shape, dt)}

                        mx.save_gguf(save_file_mlx, save_dict)
                        load_dict = mx.load(save_file_mlx)

                        self.assertIn("test", load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_gguf_metadata_basic(self):
        """Check gguf round-trips with empty metadata and with one string."""
        save_file_mlx = os.path.join(self.test_dir, "mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
        metadata = {}

        # Empty works
        mx.save_gguf(save_file_mlx, save_dict, metadata)

        # Loads without the metadata
        load_dict = mx.load(save_file_mlx)
        self.assertIn("test", load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))

        # Loads empty metadata
        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertIn("test", load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
        self.assertEqual(len(meta_load_dict), 0)

        # Loads string metadata
        metadata = {"meta": "data"}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertIn("test", load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
        self.assertEqual(len(meta_load_dict), 1)
        self.assertIn("meta", meta_load_dict)
        self.assertEqual(meta_load_dict["meta"], "data")

    def test_save_and_load_gguf_metadata_arrays(self):
        """Check gguf metadata holding scalar and one-dimensional arrays."""
        save_file_mlx = os.path.join(self.test_dir, "mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}

        # Test scalars and one dimensional arrays
        for t in [
            mx.uint8,
            mx.int8,
            mx.uint16,
            mx.int16,
            mx.uint32,
            mx.int32,
            mx.uint64,
            mx.int64,
            mx.float32,
        ]:
            for shape in [(), (2,)]:
                arr = mx.random.uniform(shape=shape).astype(t)
                metadata = {"meta": arr}
                mx.save_gguf(save_file_mlx, save_dict, metadata)
                _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
                self.assertEqual(len(meta_load_dict), 1)
                self.assertIn("meta", meta_load_dict)
                self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr))
                self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype)

        for t in [mx.float16, mx.bfloat16, mx.complex64]:
            # Build the metadata outside the assertRaises block so that only
            # an error raised by save_gguf itself satisfies the assertion.
            arr = mx.array(1, t)
            metadata = {"meta": arr}
            with self.assertRaises(ValueError):
                mx.save_gguf(save_file_mlx, save_dict, metadata)

    def test_save_and_load_gguf_metadata_mixed(self):
        """Check gguf metadata mixing strings, string lists and arrays."""
        save_file_mlx = os.path.join(self.test_dir, "mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}

        # Test string and array
        arr = mx.array(1.5)
        metadata = {"meta1": arr, "meta2": "data"}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 2)
        self.assertIn("meta1", meta_load_dict)
        self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr))
        self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype)
        self.assertIn("meta2", meta_load_dict)
        self.assertEqual(meta_load_dict["meta2"], "data")

        # Test list of strings
        metadata = {"meta": ["data1", "data2", "data345"]}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 1)
        self.assertEqual(meta_load_dict["meta"], metadata["meta"])

        # Test a combination of stuff
        metadata = {
            "meta1": ["data1", "data2", "data345"],
            "meta2": mx.array([1, 2, 3, 4]),
            "meta3": "data",
            "meta4": mx.array(1.5),
        }
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 4)
        for k, v in metadata.items():
            if isinstance(v, mx.array):
                self.assertTrue(mx.array_equal(meta_load_dict[k], v))
            else:
                self.assertEqual(meta_load_dict[k], v)

    def test_save_and_load_fs(self):
        """Round-trip ``.npy`` files with mlx using open file objects."""
        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate(self.round_trip_shapes):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.npy"
                        )
                        save_file_npy = os.path.join(
                            self.test_dir, f"npy_{dt}_{i}_fs.npy"
                        )

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        with open(save_file_mlx, "wb") as f:
                            mx.save(f, save_arr_mlx)

                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        with open(save_file_mlx, "rb") as f:
                            load_arr_mlx_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        with open(save_file_npy, "rb") as f:
                            load_arr_npy_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_savez_and_loadz(self):
        """Round-trip ``.npz`` archives (plain and compressed) with numpy."""
        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
                save_file_mlx_uncomp = os.path.join(
                    self.test_dir, f"mlx_{dt}_uncomp.npz"
                )
                save_file_npy_uncomp = os.path.join(
                    self.test_dir, f"npy_{dt}_uncomp.npz"
                )
                save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
                save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")

                # Make dictionary of multiple
                save_arrs_npy = {
                    f"save_arr_{i}": np.random.uniform(0.0, 32.0, size=shape).astype(
                        getattr(np, dt)
                    )
                    for i, shape in enumerate(shapes)
                }
                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
                expected_keys = set(save_arrs_npy)

                # Save as npz files
                np.savez(save_file_npy_uncomp, **save_arrs_npy)
                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
                np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)

                for save_file_npy, save_file_mlx in (
                    (save_file_npy_uncomp, save_file_mlx_uncomp),
                    (save_file_npy_comp, save_file_mlx_comp),
                ):
                    # Each comparison below first checks the key set so that
                    # an empty or partial load cannot pass vacuously.

                    # Load array saved by mlx as mlx array
                    load_arr_mlx_mlx = mx.load(save_file_mlx)
                    self.assertEqual(set(load_arr_mlx_mlx.keys()), expected_keys)
                    for k, v in load_arr_mlx_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by numpy as mlx arrays
                    load_arr_npy_mlx = mx.load(save_file_npy)
                    self.assertEqual(set(load_arr_npy_mlx.keys()), expected_keys)
                    for k, v in load_arr_npy_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load array saved by mlx as numpy array. np.load returns
                    # an NpzFile that keeps the archive open, so close it
                    # deterministically with a context manager.
                    with np.load(save_file_mlx) as load_arr_mlx_npy:
                        self.assertEqual(set(load_arr_mlx_npy.keys()), expected_keys)
                        for k, v in load_arr_mlx_npy.items():
                            self.assertTrue(np.array_equal(save_arrs_npy[k], v))

    def test_non_contiguous(self):
        """Save and reload a broadcast array and a transposed array."""
        a = mx.broadcast_to(mx.array([1, 2]), [4, 2])

        save_file = os.path.join(self.test_dir, "a.npy")
        mx.save(save_file, a)
        aload = mx.load(save_file)
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.safetensors")
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        # safetensors and gguf only work with row contiguous
        # make sure col contiguous is handled properly
        save_file = os.path.join(self.test_dir, "a.safetensors")
        a = mx.arange(4).reshape(2, 2).T
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

    def test_load_donation(self):
        """Check that scaling a loaded array does not raise peak memory.

        Peak memory is reset once before the first load, so the second
        reading covers both evaluations; the two readings must be equal.
        """
        x = mx.random.normal((1024,))
        mx.eval(x)
        save_file = os.path.join(self.test_dir, "donation.npy")
        mx.save(save_file, x)

        mx.metal.reset_peak_memory()

        scale = mx.array(2.0)
        y = mx.load(save_file)
        mx.eval(y)
        load_only = mx.metal.get_peak_memory()
        y = mx.load(save_file) * scale
        mx.eval(y)
        load_with_binary = mx.metal.get_peak_memory()

        self.assertEqual(load_only, load_with_binary)


if __name__ == "__main__":
    unittest.main()