Use unordered map for kwargs in export/import (#2087)

* use unordered map for kwargs in export/import

* comment
This commit is contained in:
Awni Hannun
2025-04-21 07:17:22 -07:00
committed by GitHub
parent 70ebc3b598
commit dc4eada7f0
4 changed files with 31 additions and 31 deletions

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
#include <fstream>
@@ -16,8 +16,7 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
validate_and_extract_inputs(
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
const nb::args& args,
const nb::kwargs& kwargs,
const std::string& prefix) {
@@ -30,8 +29,8 @@ validate_and_extract_inputs(
"and/or dictionary of arrays.");
}
};
std::vector<mx::array> args_;
std::map<std::string, mx::array> kwargs_;
mx::Args args_;
mx::Kwargs kwargs_;
if (args.size() == 0) {
// No args so kwargs must be keyword arrays
maybe_throw(nb::try_cast(kwargs, kwargs_));
@@ -81,9 +80,7 @@ class PyFunctionExporter {
void close() {
exporter_.close();
}
void operator()(
const std::vector<mx::array>& args,
const std::map<std::string, mx::array>& kwargs) {
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
exporter_(args, kwargs);
}
@@ -112,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
{0, 0}};
auto wrap_export_function(nb::callable fun) {
return [fun = std::move(fun)](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));
auto outputs = fun(*args, **kwargs);
std::vector<mx::array> outputs_;
if (nb::isinstance<mx::array>(outputs)) {
outputs_.push_back(nb::cast<mx::array>(outputs));
} else if (!nb::try_cast(outputs, outputs_)) {
throw std::invalid_argument(
"[export_function] Outputs can be either a single array "
"a tuple or list of arrays.");
}
return outputs_;
};
return
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));
auto outputs = fun(*args, **kwargs);
std::vector<mx::array> outputs_;
if (nb::isinstance<mx::array>(outputs)) {
outputs_.push_back(nb::cast<mx::array>(outputs));
} else if (!nb::try_cast(outputs, outputs_)) {
throw std::invalid_argument(
"[export_function] Outputs can be either a single array "
"a tuple or list of arrays.");
}
return outputs_;
};
}
void init_export(nb::module_& m) {