fix conv export (#2265)

This commit is contained in:
Awni Hannun 2025-06-10 09:34:01 -07:00 committed by GitHub
parent 7c4eb5d03e
commit 62fecf3e13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 1 deletions

View File

@ -719,9 +719,9 @@ class Convolution : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(
kernel_strides_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,

View File

@ -6,6 +6,7 @@ import tempfile
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
@ -312,6 +313,39 @@ class TestExportImport(mlx_tests.MLXTestCase):
out = imported_fun(x, y, z)[0]
self.assertTrue(mx.array_equal(expected, out))
def test_export_conv(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
class Model(nn.Module):
def __init__(self):
super().__init__()
self.c1 = nn.Conv2d(
3, 16, kernel_size=3, stride=1, padding=1, bias=False
)
self.c2 = nn.Conv2d(
16, 16, kernel_size=3, stride=2, padding=1, bias=False
)
self.c3 = nn.Conv2d(
16, 16, kernel_size=3, stride=1, padding=2, bias=False
)
def __call__(self, x):
return self.c3(self.c2(self.c1(x)))
model = Model()
mx.eval(model.parameters())
def forward(x):
return model(x)
input_data = mx.random.normal(shape=(4, 32, 32, 3))
mx.export_function(path, forward, input_data)
imported_fn = mx.import_function(path)
out = imported_fn(input_data)[0]
expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__":
unittest.main()