fix conv export (#2265)

This commit is contained in:
Awni Hannun
2025-06-10 09:34:01 -07:00
committed by GitHub
parent 7c4eb5d03e
commit 62fecf3e13
2 changed files with 35 additions and 1 deletions

View File

@@ -6,6 +6,7 @@ import tempfile
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
@@ -312,6 +313,39 @@ class TestExportImport(mlx_tests.MLXTestCase):
out = imported_fun(x, y, z)[0]
self.assertTrue(mx.array_equal(expected, out))
def test_export_conv(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
class Model(nn.Module):
def __init__(self):
super().__init__()
self.c1 = nn.Conv2d(
3, 16, kernel_size=3, stride=1, padding=1, bias=False
)
self.c2 = nn.Conv2d(
16, 16, kernel_size=3, stride=2, padding=1, bias=False
)
self.c3 = nn.Conv2d(
16, 16, kernel_size=3, stride=1, padding=2, bias=False
)
def __call__(self, x):
return self.c3(self.c2(self.c1(x)))
model = Model()
mx.eval(model.parameters())
def forward(x):
return model(x)
input_data = mx.random.normal(shape=(4, 32, 32, 3))
mx.export_function(path, forward, input_data)
imported_fn = mx.import_function(path)
out = imported_fn(input_data)[0]
expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out))
if __name__ == "__main__":
unittest.main()