Fix exporting with constants (#2769)

This commit is contained in:
Awni Hannun
2025-11-14 12:52:08 -08:00
committed by GitHub
parent 3b2ffcefc3
commit 27ff069175
3 changed files with 24 additions and 3 deletions

View File

@@ -575,6 +575,27 @@ class TestExportImport(mlx_tests.MLXTestCase):
out = imported(a)[0]
self.assertTrue(mx.allclose(expected, out))
def test_export_import_multi_with_constants(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
def fun(y):
i = y.shape[0]
x = mx.array(i)
for j in range(10):
x = x + mx.array(i + j)
return x * y.sum()
ys = [mx.array([1]), mx.array([1, 1]), mx.array([1, 1, 1])]
with mx.exporter(path, fun) as exporter:
for y in ys:
exporter(y)
imported = mx.import_function(path)
for y in ys:
self.assertEqual(imported(y)[0].item(), fun(y).item())
if __name__ == "__main__":
mlx_tests.MLXTestRunner()