From ad0dd9b5baf576376d33c2bf43ef389319e3f22d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 24 Sep 2025 15:50:53 -0700 Subject: [PATCH] Add types, fix kwarg ordering bug + test --- mlx/export.cpp | 56 ++++++++++++++++++------------ mlx/export.h | 1 + mlx/export_impl.h | 3 +- python/tests/test_export_import.py | 13 +++++++ 4 files changed, 50 insertions(+), 23 deletions(-) diff --git a/mlx/export.cpp b/mlx/export.cpp index d7f714c2e..c4981a28d 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -467,8 +467,10 @@ struct FunctionTable { }; bool shapeless; std::unordered_map> table; - Function* find(const Args& args, const Kwargs& kwargs); - std::pair emplace(const Args& args, const Kwargs& kwargs); + Function* find(const Args& args, const std::map& kwargs); + std::pair emplace( + const Args& args, + const std::map& kwargs); void insert( std::vector kwarg_keys, std::vector inputs, @@ -504,12 +506,15 @@ struct FunctionTable { } private: - bool match(const Args& args, const Kwargs& kwargs, const Function& fun); + bool match( + const Args& args, + const std::map& kwargs, + const Function& fun); }; bool FunctionTable::match( const Args& args, - const Kwargs& kwargs, + const std::map& kwargs, const Function& fun) { for (auto& k : fun.kwarg_keys) { if (kwargs.find(k) == kwargs.end()) { @@ -537,9 +542,7 @@ bool FunctionTable::match( return false; } } - auto sorted_kwargs = - std::map(kwargs.begin(), kwargs.end()); - for (auto& [_, in] : sorted_kwargs) { + for (auto& [_, in] : kwargs) { if (!match_inputs(in, fun.inputs[i++])) { return false; } @@ -550,7 +553,7 @@ bool FunctionTable::match( std::pair FunctionTable::emplace( const Args& args, - const Kwargs& kwargs) { + const std::map& kwargs) { auto n_inputs = args.size() + kwargs.size(); auto [it, _] = table.emplace(n_inputs, std::vector{}); auto& funs_vec = it->second; @@ -567,7 +570,7 @@ std::pair FunctionTable::emplace( FunctionTable::Function* FunctionTable::find( const Args& args, - const Kwargs& kwargs) { + const std::map& kwargs) { auto n_inputs = args.size() + kwargs.size(); auto it = table.find(n_inputs); if (it == table.end()) { @@ -611,7 +614,8 @@ void FunctionExporter::close() { void FunctionExporter::export_with_callback( const std::vector& inputs, const std::vector& outputs, - const std::vector& tape) { + const std::vector& tape, + const std::vector& kwarg_keys) { NodeNamer namer{}; auto to_vector_data = [&namer](const auto& arrays) { std::vector> data; @@ -622,10 +626,15 @@ void FunctionExporter::export_with_callback( }; // Callback on the inputs - callback({{"inputs", to_vector_data(inputs)}}); + callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}}); + std::vector> keyword_inputs; + for (int i = inputs.size() - kwarg_keys.size(); i < inputs.size(); ++i) { + keyword_inputs.emplace_back(kwarg_keys[i], namer.get_name(inputs[i])); + } + callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}}); // Callback on the outputs - callback({{"outputs", to_vector_data(outputs)}}); + callback({{"type", "outputs"}, {"outputs", to_vector_data(outputs)}}); // Callback on the constants { @@ -642,7 +651,7 @@ void FunctionExporter::export_with_callback( new_constants.emplace_back(namer.get_name(arr), arr); } } - callback({{"constants", new_constants}}); + callback({{"type", "constants"}, {"constants", new_constants}}); } auto factory = PrimitiveFactory(); @@ -653,10 +662,11 @@ void FunctionExporter::export_with_callback( } auto [name, state] = factory.extract_state(arr.primitive_ptr()); callback( - {{"inputs", to_vector_data(arr.inputs())}, + {{"type", "primitive"}, + {"inputs", to_vector_data(arr.inputs())}, {"outputs", to_vector_data(arr.outputs())}, - {"primitive", name}, - {"state", state}}); + {"name", name}, + {"arguments", state}}); } } @@ -665,7 +675,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { throw std::runtime_error( "[export_function] Attempting to write after exporting is closed."); } - auto [fentry, inserted] = ftable->emplace(args, kwargs); + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs); if (!inserted) { throw std::runtime_error( "[export_function] Attempting to export a function twice with " @@ -675,8 +687,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { // Flatten the inputs to the function for tracing std::vector kwarg_keys; auto inputs = args; - auto sorted_kwargs = - std::map(kwargs.begin(), kwargs.end()); for (auto& [k, v] : sorted_kwargs) { kwarg_keys.push_back(k); inputs.push_back(v); @@ -710,7 +720,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { count++; if (callback) { - export_with_callback(trace_inputs, trace_outputs, tape); + export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys); return; } @@ -908,7 +918,9 @@ std::vector ImportedFunction::operator()(const Args& args) const { std::vector ImportedFunction::operator()( const Args& args, const Kwargs& kwargs) const { - auto* fun = ftable->find(args, kwargs); + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + auto* fun = ftable->find(args, sorted_kwargs); if (fun == nullptr) { std::ostringstream msg; msg << "[import_function::call] No imported function found which matches " @@ -927,7 +939,7 @@ std::vector ImportedFunction::operator()( } auto inputs = args; - for (auto& [_, v] : kwargs) { + for (auto& [_, v] : sorted_kwargs) { inputs.push_back(v); } return detail::compile_replace( diff --git a/mlx/export.h b/mlx/export.h index 934e5471c..715dac2c8 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -31,6 +31,7 @@ using ExportCallbackInput = std::unordered_map< std::variant< std::vector>, std::vector>, + std::vector>, std::vector, std::string>>; using ExportCallback = std::function; diff --git a/mlx/export_impl.h b/mlx/export_impl.h index 82756c23f..0e7818981 100644 --- a/mlx/export_impl.h +++ b/mlx/export_impl.h @@ -70,7 +70,8 @@ struct FunctionExporter { void export_with_callback( const std::vector& inputs, const std::vector& outputs, - const std::vector& tape); + const std::vector& tape, + const std::vector& kwarg_keys); std::set constants; int count{0}; bool closed{false}; diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 1d8af8509..11789653c 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -485,6 +485,19 @@ class TestExportImport(mlx_tests.MLXTestCase): mx.array_equal(imported_fn(input_data)[0], model(input_data)) ) + def test_export_kwarg_ordering(self): + path = os.path.join(self.test_dir, "fun.mlxfn") + + def fn(x, y): + return x - y + + mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0)) + imported = mx.import_function(path) + out = imported(x=mx.array(2.0), y=mx.array(3.0))[0] + self.assertEqual(out.item(), -1.0) + out = imported(y=mx.array(2.0), x=mx.array(3.0))[0] + self.assertEqual(out.item(), 1.0) + if __name__ == "__main__": mlx_tests.MLXTestRunner()