Export with callback (#2612)

* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
This commit is contained in:
Awni Hannun
2025-10-08 19:24:33 -07:00
committed by GitHub
parent 85a8824a8c
commit e89e8b4272
10 changed files with 370 additions and 33 deletions

View File

@@ -319,7 +319,7 @@ class TestCompile(mlx_tests.MLXTestCase):
# Check the state is unchanged
self.assertEqual(state["y"], 2)
# Check the udpated state is used
# Check the updated state is used
state["y"] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)

View File

@@ -485,6 +485,52 @@ class TestExportImport(mlx_tests.MLXTestCase):
mx.array_equal(imported_fn(input_data)[0], model(input_data))
)
def test_export_kwarg_ordering(self):
path = os.path.join(self.test_dir, "fun.mlxfn")
def fn(x, y):
return x - y
mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
imported = mx.import_function(path)
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
self.assertEqual(out.item(), -1.0)
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
self.assertEqual(out.item(), 1.0)
def test_export_with_callback(self):
def fn(x, y):
return mx.log(mx.abs(x - y))
n_in = None
n_out = None
n_const = None
keywords = None
primitives = []
def callback(args):
nonlocal n_in, n_out, n_const, keywords, primitives
t = args["type"]
if t == "inputs":
n_in = len(args["inputs"])
elif args["type"] == "outputs":
n_out = len(args["outputs"])
elif args["type"] == "keyword_inputs":
keywords = args["keywords"]
elif t == "constants":
n_const = len(args["constants"])
elif t == "primitive":
primitives.append(args["name"])
mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0))
self.assertEqual(n_in, 2)
self.assertEqual(n_out, 1)
self.assertEqual(n_const, 0)
self.assertEqual(len(keywords), 1)
self.assertEqual(keywords[0][0], "y")
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])
if __name__ == "__main__":
mlx_tests.MLXTestRunner()