mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix shapeless export (#2148)
This commit is contained in:
		| @@ -242,6 +242,7 @@ class TestExportImport(mlx_tests.MLXTestCase): | ||||
|  | ||||
|     def test_leaks(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|         mx.synchronize() | ||||
|         if mx.metal.is_available(): | ||||
|             mem_pre = mx.get_active_memory() | ||||
|         else: | ||||
| @@ -267,6 +268,24 @@ class TestExportImport(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertEqual(mem_pre, mem_post) | ||||
|  | ||||
|     def test_export_import_shapeless(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|  | ||||
|         def fun(*args): | ||||
|             return sum(args) | ||||
|  | ||||
|         with mx.exporter(path, fun, shapeless=True) as exporter: | ||||
|             exporter(mx.array(1)) | ||||
|             exporter(mx.array(1), mx.array(2)) | ||||
|             exporter(mx.array(1), mx.array(2), mx.array(3)) | ||||
|  | ||||
|         f2 = mx.import_function(path) | ||||
|         self.assertEqual(f2(mx.array(1))[0].item(), 1) | ||||
|         self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2) | ||||
|         self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3) | ||||
|         with self.assertRaises(ValueError): | ||||
|             f2(mx.array(10), mx.array([5, 10, 20])) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun