Add tests for export including control flow models and quantized models (#2430)

* Add tests for export, including control flow export and quantized model export.

* Skip quantization related test for CUDA backend.
This commit is contained in:
junpeiz 2025-07-31 11:06:26 -07:00 committed by GitHub
parent da5912e4f2
commit 8b25ce62d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 0 deletions

View File

@ -74,4 +74,5 @@ cuda_skip = {
"TestQuantized.test_small_matrix", "TestQuantized.test_small_matrix",
"TestQuantized.test_throw", "TestQuantized.test_throw",
"TestQuantized.test_vjp_scales_biases", "TestQuantized.test_vjp_scales_biases",
"TestExportImport.test_export_quantized_model",
} }

View File

@ -346,6 +346,46 @@ class TestExportImport(mlx_tests.MLXTestCase):
expected = forward(input_data) expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out)) self.assertTrue(mx.allclose(expected, out))
def test_export_control_flow(self):
def fun(x, y):
if y.shape[0] <= 2:
return x + y
else:
return x + 2 * y
for y in (mx.array([1, 2, 3]), mx.array([1, 2])):
for shapeless in (True, False):
with self.subTest(y=y, shapeless=shapeless):
x = mx.array(1)
export_path = os.path.join(self.test_dir, "control_flow.mlxfn")
mx.export_function(export_path, fun, x, y, shapeless=shapeless)
imported_fn = mx.import_function(export_path)
self.assertTrue(mx.array_equal(imported_fn(x, y)[0], fun(x, y)))
def test_export_quantized_model(self):
for shapeless in (True, False):
with self.subTest(shapeless=shapeless):
model = nn.Sequential(
nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 1024)
)
model.eval()
mx.eval(model.parameters())
input_data = mx.ones(shape=(512, 1024))
nn.quantize(model)
self.assertTrue(isinstance(model.layers[0], nn.QuantizedLinear))
self.assertTrue(isinstance(model.layers[2], nn.QuantizedLinear))
mx.eval(model.parameters())
export_path = os.path.join(self.test_dir, "quantized_linear.mlxfn")
mx.export_function(export_path, model, input_data, shapeless=shapeless)
imported_fn = mx.import_function(export_path)
self.assertTrue(
mx.array_equal(imported_fn(input_data)[0], model(input_data))
)
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()