From 62fecf3e13c8f12ee532b56079b073d2afcfa9ef Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 10 Jun 2025 09:34:01 -0700 Subject: [PATCH] fix conv export (#2265) --- mlx/primitives.h | 2 +- python/tests/test_export_import.py | 34 ++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/mlx/primitives.h b/mlx/primitives.h index cc60bcfb9..4b18430ca 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -719,9 +719,9 @@ class Convolution : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( + kernel_strides_, padding_lo_, padding_hi_, - kernel_strides_, kernel_dilation_, input_dilation_, groups_, diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index ef9827cbe..0fd8bfd87 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -6,6 +6,7 @@ import tempfile import unittest import mlx.core as mx +import mlx.nn as nn import mlx_tests @@ -312,6 +313,39 @@ class TestExportImport(mlx_tests.MLXTestCase): out = imported_fun(x, y, z)[0] self.assertTrue(mx.array_equal(expected, out)) + def test_export_conv(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.c1 = nn.Conv2d( + 3, 16, kernel_size=3, stride=1, padding=1, bias=False + ) + self.c2 = nn.Conv2d( + 16, 16, kernel_size=3, stride=2, padding=1, bias=False + ) + self.c3 = nn.Conv2d( + 16, 16, kernel_size=3, stride=1, padding=2, bias=False + ) + + def __call__(self, x): + return self.c3(self.c2(self.c1(x))) + + model = Model() + mx.eval(model.parameters()) + + def forward(x): + return model(x) + + input_data = mx.random.normal(shape=(4, 32, 32, 3)) + mx.export_function(path, forward, input_data) + + imported_fn = mx.import_function(path) + out = imported_fn(input_data)[0] + expected = forward(input_data) + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": unittest.main()