Use unordered map for kwargs in export/import (#2087)

* use unordered map for kwargs in export/import

* comment
This commit is contained in:
Awni Hannun 2025-04-21 07:17:22 -07:00 committed by GitHub
parent 70ebc3b598
commit dc4eada7f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 31 deletions

View File

@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/export.h"
#include <map>
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@ -481,7 +482,9 @@ bool FunctionTable::match(
return false;
}
}
for (auto& [_, in] : kwargs) {
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [_, in] : sorted_kwargs) {
if (!match_inputs(in, fun.inputs[i++])) {
return false;
}
@ -557,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
// Flatten the inputs to the function for tracing
std::vector<std::string> kwarg_keys;
auto inputs = args;
for (auto& [k, v] : kwargs) {
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [k, v] : sorted_kwargs) {
kwarg_keys.push_back(k);
inputs.push_back(v);
}

View File

@ -2,14 +2,14 @@
#pragma once
#include <map>
#include <set>
#include <unordered_map>
#include "mlx/array.h"
namespace mlx::core {
using Args = std::vector<array>;
using Kwargs = std::map<std::string, array>;
using Kwargs = std::unordered_map<std::string, array>;
struct FunctionExporter;

View File

@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
#include <fstream>
@ -16,8 +16,7 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
validate_and_extract_inputs(
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
const nb::args& args,
const nb::kwargs& kwargs,
const std::string& prefix) {
@ -30,8 +29,8 @@ validate_and_extract_inputs(
"and/or dictionary of arrays.");
}
};
std::vector<mx::array> args_;
std::map<std::string, mx::array> kwargs_;
mx::Args args_;
mx::Kwargs kwargs_;
if (args.size() == 0) {
// No args so kwargs must be keyword arrays
maybe_throw(nb::try_cast(kwargs, kwargs_));
@ -81,9 +80,7 @@ class PyFunctionExporter {
void close() {
exporter_.close();
}
void operator()(
const std::vector<mx::array>& args,
const std::map<std::string, mx::array>& kwargs) {
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
exporter_(args, kwargs);
}
@ -112,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
{0, 0}};
auto wrap_export_function(nb::callable fun) {
return [fun = std::move(fun)](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));
auto outputs = fun(*args, **kwargs);
std::vector<mx::array> outputs_;
if (nb::isinstance<mx::array>(outputs)) {
outputs_.push_back(nb::cast<mx::array>(outputs));
} else if (!nb::try_cast(outputs, outputs_)) {
throw std::invalid_argument(
"[export_function] Outputs can be either a single array "
"a tuple or list of arrays.");
}
return outputs_;
};
return
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));
auto outputs = fun(*args, **kwargs);
std::vector<mx::array> outputs_;
if (nb::isinstance<mx::array>(outputs)) {
outputs_.push_back(nb::cast<mx::array>(outputs));
} else if (!nb::try_cast(outputs, outputs_)) {
throw std::invalid_argument(
"[export_function] Outputs can be either a single array "
"a tuple or list of arrays.");
}
return outputs_;
};
}
void init_export(nb::module_& m) {

View File

@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
TEST_CASE("test export functions with kwargs") {
std::string file_path = get_temp_file("model.mlxfn");
auto fun =
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
return {kwargs.at("x") + kwargs.at("y")};
};