diff --git a/mlx/export.cpp b/mlx/export.cpp index effc7a0c1..c9139e156 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -470,6 +470,9 @@ bool FunctionTable::match( if (x.dtype() != y.dtype()) { return false; } + if (x.ndim() != y.ndim()) { + return false; + } if (!shapeless && x.shape() != y.shape()) { return false; } diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 2b4b425ca..0190827bd 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -242,6 +242,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: @@ -267,6 +268,24 @@ class TestExportImport(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_export_import_shapeless(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(*args): + return sum(args) + + with mx.exporter(path, fun, shapeless=True) as exporter: + exporter(mx.array(1)) + exporter(mx.array(1), mx.array(2)) + exporter(mx.array(1), mx.array(2), mx.array(3)) + + f2 = mx.import_function(path) + self.assertEqual(f2(mx.array(1))[0].item(), 1) + self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2) + self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3) + with self.assertRaises(ValueError): + f2(mx.array(10), mx.array([5, 10, 20])) + if __name__ == "__main__": unittest.main()