From 8b25ce62d585ebe74710691e99acfcd15d331ada Mon Sep 17 00:00:00 2001 From: junpeiz Date: Thu, 31 Jul 2025 11:06:26 -0700 Subject: [PATCH] Add tests for export including control flow models and quantized models (#2430) * Add tests for export, including control flow export and quantized model export. * Skip quantization related test for CUDA backend. --- python/tests/cuda_skip.py | 1 + python/tests/test_export_import.py | 40 ++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 5bb465e1eb..e14aa675e0 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -74,4 +74,5 @@ cuda_skip = { "TestQuantized.test_small_matrix", "TestQuantized.test_throw", "TestQuantized.test_vjp_scales_biases", + "TestExportImport.test_export_quantized_model", } diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 099be0cc02..71fb5b27ff 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -346,6 +346,46 @@ class TestExportImport(mlx_tests.MLXTestCase): expected = forward(input_data) self.assertTrue(mx.allclose(expected, out)) + def test_export_control_flow(self): + + def fun(x, y): + if y.shape[0] <= 2: + return x + y + else: + return x + 2 * y + + for y in (mx.array([1, 2, 3]), mx.array([1, 2])): + for shapeless in (True, False): + with self.subTest(y=y, shapeless=shapeless): + x = mx.array(1) + export_path = os.path.join(self.test_dir, "control_flow.mlxfn") + mx.export_function(export_path, fun, x, y, shapeless=shapeless) + + imported_fn = mx.import_function(export_path) + self.assertTrue(mx.array_equal(imported_fn(x, y)[0], fun(x, y))) + + def test_export_quantized_model(self): + for shapeless in (True, False): + with self.subTest(shapeless=shapeless): + model = nn.Sequential( + nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 1024) + ) + model.eval() + mx.eval(model.parameters()) + input_data = mx.ones(shape=(512, 1024)) + nn.quantize(model) + self.assertTrue(isinstance(model.layers[0], nn.QuantizedLinear)) + self.assertTrue(isinstance(model.layers[2], nn.QuantizedLinear)) + mx.eval(model.parameters()) + + export_path = os.path.join(self.test_dir, "quantized_linear.mlxfn") + mx.export_function(export_path, model, input_data, shapeless=shapeless) + + imported_fn = mx.import_function(export_path) + self.assertTrue( + mx.array_equal(imported_fn(input_data)[0], model(input_data)) + ) + if __name__ == "__main__": mlx_tests.MLXTestRunner()