mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 09:29:26 +08:00
Fix exporting with constants (#2769)
This commit is contained in:
@@ -575,6 +575,27 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
out = imported(a)[0]
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_export_import_multi_with_constants(self):
|
||||
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
|
||||
def fun(y):
|
||||
i = y.shape[0]
|
||||
x = mx.array(i)
|
||||
for j in range(10):
|
||||
x = x + mx.array(i + j)
|
||||
return x * y.sum()
|
||||
|
||||
ys = [mx.array([1]), mx.array([1, 1]), mx.array([1, 1, 1])]
|
||||
|
||||
with mx.exporter(path, fun) as exporter:
|
||||
for y in ys:
|
||||
exporter(y)
|
||||
|
||||
imported = mx.import_function(path)
|
||||
for y in ys:
|
||||
self.assertEqual(imported(y)[0].item(), fun(y).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user