From a982077f6e57a6a13f09d54b110f0ffb502e1b7c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 8 Oct 2025 15:46:00 -0700 Subject: [PATCH] cleanup, test, fix --- mlx/export.cpp | 6 +++--- python/tests/test_export_import.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/mlx/export.cpp b/mlx/export.cpp index c4981a28d..bfd1f9279 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -222,7 +222,6 @@ void extract_state(const T state, std::vector& unpacked_state) { } } -// std::vector extract_state(const Primitive& p) { template std::vector primitive_state(const Primitive& p) { std::vector state; @@ -628,8 +627,9 @@ void FunctionExporter::export_with_callback( // Callback on the inputs callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}}); std::vector> keyword_inputs; - for (int i = inputs.size() - kwarg_keys.size(); i < inputs.size(); ++i) { - keyword_inputs.emplace_back(kwarg_keys[i], namer.get_name(inputs[i])); + for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size(); + ++i, ++j) { + keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i])); } callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}}); diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 11789653c..4a4ca82a5 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -498,6 +498,39 @@ class TestExportImport(mlx_tests.MLXTestCase): out = imported(y=mx.array(2.0), x=mx.array(3.0))[0] self.assertEqual(out.item(), 1.0) + def test_export_with_callback(self): + + def fn(x, y): + return mx.log(mx.abs(x - y)) + + n_in = None + n_out = None + n_const = None + keywords = None + primitives = [] + + def callback(args): + nonlocal n_in, n_out, n_const, keywords, primitives + t = args["type"] + if t == "inputs": + n_in = len(args["inputs"]) + elif args["type"] == "outputs": + n_out = len(args["outputs"]) + elif args["type"] == "keyword_inputs": + keywords = args["keywords"] + elif t == "constants": + n_const = len(args["constants"]) + elif t == "primitive": + primitives.append(args["name"]) + + mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0)) + self.assertEqual(n_in, 2) + self.assertEqual(n_out, 1) + self.assertEqual(n_const, 0) + self.assertEqual(len(keywords), 1) + self.assertEqual(keywords[0][0], "y") + self.assertEqual(primitives, ["Subtract", "Abs", "Log"]) + if __name__ == "__main__": mlx_tests.MLXTestRunner()