mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
		
							
								
								
									
										244
									
								
								python/tests/test_export_import.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										244
									
								
								python/tests/test_export_import.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,244 @@ | ||||
| # Copyright © 2024 Apple Inc. | ||||
|  | ||||
| import os | ||||
| import tempfile | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestExportImport(mlx_tests.MLXTestCase): | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         cls.test_dir_fid = tempfile.TemporaryDirectory() | ||||
|         cls.test_dir = cls.test_dir_fid.name | ||||
|         if not os.path.isdir(cls.test_dir): | ||||
|             os.mkdir(cls.test_dir) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         cls.test_dir_fid.cleanup() | ||||
|  | ||||
|     def test_basic_export_import(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|  | ||||
|         # Function with no inputs | ||||
|         def fun(): | ||||
|             return mx.zeros((3, 3)) | ||||
|  | ||||
|         mx.export_function(path, fun) | ||||
|         imported = mx.import_function(path) | ||||
|  | ||||
|         expected = fun() | ||||
|         (out,) = imported() | ||||
|         self.assertTrue(mx.array_equal(out, expected)) | ||||
|  | ||||
|         # Simple function with inputs | ||||
|         def fun(x): | ||||
|             return mx.abs(mx.sin(x)) | ||||
|  | ||||
|         inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0]) | ||||
|  | ||||
|         mx.export_function(path, fun, inputs) | ||||
|         imported = mx.import_function(path) | ||||
|  | ||||
|         expected = fun(inputs) | ||||
|         (out,) = imported(inputs) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|         # Inputs in a list or tuple | ||||
|         def fun(x): | ||||
|             x = mx.abs(mx.sin(x)) | ||||
|             return x | ||||
|  | ||||
|         mx.export_function(path, fun, [inputs]) | ||||
|         imported = mx.import_function(path) | ||||
|  | ||||
|         expected = fun(inputs) | ||||
|         (out,) = imported([inputs]) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|         (out,) = imported(inputs) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|         mx.export_function(path, fun, (inputs,)) | ||||
|         imported = mx.import_function(path) | ||||
|         (out,) = imported((inputs,)) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|         # Outputs in a list | ||||
|         def fun(x): | ||||
|             return [mx.abs(mx.sin(x))] | ||||
|  | ||||
|         mx.export_function(path, fun, inputs) | ||||
|         imported = mx.import_function(path) | ||||
|         (out,) = imported(inputs) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|         # Outputs in a tuple | ||||
|         def fun(x): | ||||
|             return (mx.abs(mx.sin(x)),) | ||||
|  | ||||
|         mx.export_function(path, fun, inputs) | ||||
|         imported = mx.import_function(path) | ||||
|         (out,) = imported(inputs) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|         # Check throws on invalid inputs / outputs | ||||
|         def fun(x): | ||||
|             return mx.abs(x) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.export_function(path, fun, "hi") | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.export_function(path, fun, mx.array(1.0), "hi") | ||||
|  | ||||
|         def fun(x): | ||||
|             return mx.abs(x[0][0]) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.export_function(path, fun, [[mx.array(1.0)]]) | ||||
|  | ||||
|         def fun(): | ||||
|             return (mx.zeros((3, 3)), 1) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.export_function(path, fun) | ||||
|  | ||||
|         def fun(): | ||||
|             return (mx.zeros((3, 3)), [mx.zeros((3, 3))]) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.export_function(path, fun) | ||||
|  | ||||
|         def fun(x, y): | ||||
|             return x + y | ||||
|  | ||||
|         mx.export_function(path, fun, mx.array(1.0), mx.array(1.0)) | ||||
|         imported = mx.import_function(path) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported(mx.array(1.0), 1.0) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported(mx.array(1.0), mx.array(1.0), mx.array(1.0)) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported(mx.array(1.0), [mx.array(1.0)]) | ||||
|  | ||||
|     def test_export_random_sample(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|  | ||||
|         mx.random.seed(5) | ||||
|  | ||||
|         def fun(): | ||||
|             return mx.random.uniform(shape=(3,)) | ||||
|  | ||||
|         mx.export_function(path, fun) | ||||
|         imported = mx.import_function(path) | ||||
|  | ||||
|         (out,) = imported() | ||||
|  | ||||
|         mx.random.seed(5) | ||||
|         expected = fun() | ||||
|  | ||||
|         self.assertTrue(mx.array_equal(out, expected)) | ||||
|  | ||||
|     def test_export_with_kwargs(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|  | ||||
|         def fun(x, z=None): | ||||
|             out = x | ||||
|             if z is not None: | ||||
|                 out += z | ||||
|             return out | ||||
|  | ||||
|         x = mx.array([1, 2, 3]) | ||||
|         y = mx.array([1, 1, 0]) | ||||
|         z = mx.array([2, 2, 2]) | ||||
|  | ||||
|         mx.export_function(path, fun, (x,), {"z": z}) | ||||
|         imported_fun = mx.import_function(path) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported_fun(x, z) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported_fun(x, y=z) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported_fun((x,), {"y": z}) | ||||
|  | ||||
|         out = imported_fun(x, z=z)[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5]))) | ||||
|  | ||||
|         out = imported_fun((x,), {"z": z})[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5]))) | ||||
|  | ||||
|         mx.export_function(path, fun, x, z=z) | ||||
|         imported_fun = mx.import_function(path) | ||||
|         out = imported_fun(x, z=z)[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5]))) | ||||
|  | ||||
|         out = imported_fun((x,), {"z": z})[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5]))) | ||||
|  | ||||
|         # Only specify kwargs | ||||
|         mx.export_function(path, fun, x=x, z=z) | ||||
|         imported_fun = mx.import_function(path) | ||||
|         with self.assertRaises(ValueError): | ||||
|             out = imported_fun(x, z=z)[0] | ||||
|  | ||||
|         out = imported_fun(x=x, z=z)[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5]))) | ||||
|  | ||||
|         out = imported_fun({"x": x, "z": z})[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5]))) | ||||
|  | ||||
|     def test_export_variable_inputs(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|  | ||||
|         def fun(x, y, z=None): | ||||
|             out = x + y | ||||
|             if z is not None: | ||||
|                 out += z | ||||
|             return out | ||||
|  | ||||
|         with mx.exporter(path, fun) as exporter: | ||||
|             exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1])) | ||||
|             exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2])) | ||||
|  | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1])) | ||||
|  | ||||
|         imported_fun = mx.import_function(path) | ||||
|         out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4]))) | ||||
|  | ||||
|         out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0] | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6]))) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1])) | ||||
|  | ||||
|         # A function with a large constant | ||||
|         constant = mx.zeros((16, 2048)) | ||||
|         mx.eval(constant) | ||||
|  | ||||
|         def fun(*args): | ||||
|             return constant + sum(args) | ||||
|  | ||||
|         with mx.exporter(path, fun) as exporter: | ||||
|             for i in range(5): | ||||
|                 exporter(*[mx.array(1)] * i) | ||||
|  | ||||
|         # Check the exported file size < constant size + small amount | ||||
|         constants_size = constant.nbytes + 8192 | ||||
|         self.assertTrue(os.path.getsize(path) < constants_size) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
| @@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|     def setUpClass(cls): | ||||
|         cls.test_dir_fid = tempfile.TemporaryDirectory() | ||||
|         cls.test_dir = cls.test_dir_fid.name | ||||
|         if not os.path.isdir(cls.test_dir): | ||||
|             os.mkdir(cls.test_dir) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         cls.test_dir_fid.cleanup() | ||||
|  | ||||
|     def test_save_and_load(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         for dt in self.dtypes: | ||||
|             with self.subTest(dtype=dt): | ||||
|                 for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): | ||||
| @@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|                         self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) | ||||
|  | ||||
|     def test_save_and_load_safetensors(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         test_file = os.path.join(self.test_dir, "test.safetensors") | ||||
|         with self.assertRaises(Exception): | ||||
|             mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0}) | ||||
| @@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|                         self.assertTrue(np.array_equal(save_arrs_npy[k], v)) | ||||
|  | ||||
|     def test_non_contiguous(self): | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
|         a = mx.broadcast_to(mx.array([1, 2]), [4, 2]) | ||||
|  | ||||
|         save_file = os.path.join(self.test_dir, "a.npy") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun