diff --git a/mlx/export.cpp b/mlx/export.cpp index 51a04a59f..effc7a0c1 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. #include "mlx/export.h" +#include #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" @@ -481,7 +482,9 @@ bool FunctionTable::match( return false; } } - for (auto& [_, in] : kwargs) { + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + for (auto& [_, in] : sorted_kwargs) { if (!match_inputs(in, fun.inputs[i++])) { return false; } @@ -557,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { // Flatten the inputs to the function for tracing std::vector kwarg_keys; auto inputs = args; - for (auto& [k, v] : kwargs) { + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + for (auto& [k, v] : sorted_kwargs) { kwarg_keys.push_back(k); inputs.push_back(v); } diff --git a/mlx/export.h b/mlx/export.h index da090510b..c6859c6d8 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -2,14 +2,14 @@ #pragma once -#include #include +#include #include "mlx/array.h" namespace mlx::core { using Args = std::vector; -using Kwargs = std::map; +using Kwargs = std::unordered_map; struct FunctionExporter; diff --git a/python/src/export.cpp b/python/src/export.cpp index 0f3bbc1b6..30062ae37 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -1,8 +1,8 @@ // Copyright © 2024 Apple Inc. #include -#include #include #include +#include #include #include @@ -16,8 +16,7 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -std::pair, std::map> -validate_and_extract_inputs( +std::pair validate_and_extract_inputs( const nb::args& args, const nb::kwargs& kwargs, const std::string& prefix) { @@ -30,8 +29,8 @@ validate_and_extract_inputs( "and/or dictionary of arrays."); } }; - std::vector args_; - std::map kwargs_; + mx::Args args_; + mx::Kwargs kwargs_; if (args.size() == 0) { // No args so kwargs must be keyword arrays maybe_throw(nb::try_cast(kwargs, kwargs_)); @@ -81,9 +80,7 @@ class PyFunctionExporter { void close() { exporter_.close(); } - void operator()( - const std::vector& args, - const std::map& kwargs) { + void operator()(const mx::Args& args, const mx::Kwargs& kwargs) { exporter_(args, kwargs); } @@ -112,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = { {0, 0}}; auto wrap_export_function(nb::callable fun) { - return [fun = std::move(fun)]( - const std::vector& args_, - const std::map& kwargs_) { - auto kwargs = nb::dict(); - kwargs.update(nb::cast(kwargs_)); - auto args = nb::tuple(nb::cast(args_)); - auto outputs = fun(*args, **kwargs); - std::vector outputs_; - if (nb::isinstance(outputs)) { - outputs_.push_back(nb::cast(outputs)); - } else if (!nb::try_cast(outputs, outputs_)) { - throw std::invalid_argument( - "[export_function] Outputs can be either a single array " - "a tuple or list of arrays."); - } - return outputs_; - }; + return + [fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) { + auto kwargs = nb::dict(); + kwargs.update(nb::cast(kwargs_)); + auto args = nb::tuple(nb::cast(args_)); + auto outputs = fun(*args, **kwargs); + std::vector outputs_; + if (nb::isinstance(outputs)) { + outputs_.push_back(nb::cast(outputs)); + } else if (!nb::try_cast(outputs, outputs_)) { + throw std::invalid_argument( + "[export_function] Outputs can be either a single array " + "a tuple or list of arrays."); + } + return outputs_; + }; } void init_export(nb::module_& m) { diff --git a/tests/export_import_tests.cpp b/tests/export_import_tests.cpp index 83ee1e590..7ad2c640d 100644 --- a/tests/export_import_tests.cpp +++ b/tests/export_import_tests.cpp @@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") { TEST_CASE("test export functions with kwargs") { std::string file_path = get_temp_file("model.mlxfn"); - auto fun = - [](const std::map& kwargs) -> std::vector { + auto fun = [](const Kwargs& kwargs) -> std::vector { return {kwargs.at("x") + kwargs.at("y")}; };