mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 18:56:39 +08:00
Add tests for export including control flow models and quantized models (#2430)
* Add tests for export, including control flow export and quantized model export. * Skip quantization related test for CUDA backend.
This commit is contained in:
parent
da5912e4f2
commit
8b25ce62d5
@ -74,4 +74,5 @@ cuda_skip = {
|
||||
"TestQuantized.test_small_matrix",
|
||||
"TestQuantized.test_throw",
|
||||
"TestQuantized.test_vjp_scales_biases",
|
||||
"TestExportImport.test_export_quantized_model",
|
||||
}
|
||||
|
@ -346,6 +346,46 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
expected = forward(input_data)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_export_control_flow(self):
|
||||
|
||||
def fun(x, y):
|
||||
if y.shape[0] <= 2:
|
||||
return x + y
|
||||
else:
|
||||
return x + 2 * y
|
||||
|
||||
for y in (mx.array([1, 2, 3]), mx.array([1, 2])):
|
||||
for shapeless in (True, False):
|
||||
with self.subTest(y=y, shapeless=shapeless):
|
||||
x = mx.array(1)
|
||||
export_path = os.path.join(self.test_dir, "control_flow.mlxfn")
|
||||
mx.export_function(export_path, fun, x, y, shapeless=shapeless)
|
||||
|
||||
imported_fn = mx.import_function(export_path)
|
||||
self.assertTrue(mx.array_equal(imported_fn(x, y)[0], fun(x, y)))
|
||||
|
||||
def test_export_quantized_model(self):
|
||||
for shapeless in (True, False):
|
||||
with self.subTest(shapeless=shapeless):
|
||||
model = nn.Sequential(
|
||||
nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 1024)
|
||||
)
|
||||
model.eval()
|
||||
mx.eval(model.parameters())
|
||||
input_data = mx.ones(shape=(512, 1024))
|
||||
nn.quantize(model)
|
||||
self.assertTrue(isinstance(model.layers[0], nn.QuantizedLinear))
|
||||
self.assertTrue(isinstance(model.layers[2], nn.QuantizedLinear))
|
||||
mx.eval(model.parameters())
|
||||
|
||||
export_path = os.path.join(self.test_dir, "quantized_linear.mlxfn")
|
||||
mx.export_function(export_path, model, input_data, shapeless=shapeless)
|
||||
|
||||
imported_fn = mx.import_function(export_path)
|
||||
self.assertTrue(
|
||||
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Loading…
Reference in New Issue
Block a user