mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import * comment
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/map.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <fstream>
|
||||
@@ -16,8 +16,7 @@ namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
|
||||
validate_and_extract_inputs(
|
||||
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs,
|
||||
const std::string& prefix) {
|
||||
@@ -30,8 +29,8 @@ validate_and_extract_inputs(
|
||||
"and/or dictionary of arrays.");
|
||||
}
|
||||
};
|
||||
std::vector<mx::array> args_;
|
||||
std::map<std::string, mx::array> kwargs_;
|
||||
mx::Args args_;
|
||||
mx::Kwargs kwargs_;
|
||||
if (args.size() == 0) {
|
||||
// No args so kwargs must be keyword arrays
|
||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||
@@ -81,9 +80,7 @@ class PyFunctionExporter {
|
||||
void close() {
|
||||
exporter_.close();
|
||||
}
|
||||
void operator()(
|
||||
const std::vector<mx::array>& args,
|
||||
const std::map<std::string, mx::array>& kwargs) {
|
||||
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
|
||||
exporter_(args, kwargs);
|
||||
}
|
||||
|
||||
@@ -112,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
|
||||
{0, 0}};
|
||||
|
||||
auto wrap_export_function(nb::callable fun) {
|
||||
return [fun = std::move(fun)](
|
||||
const std::vector<mx::array>& args_,
|
||||
const std::map<std::string, mx::array>& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
kwargs.update(nb::cast(kwargs_));
|
||||
auto args = nb::tuple(nb::cast(args_));
|
||||
auto outputs = fun(*args, **kwargs);
|
||||
std::vector<mx::array> outputs_;
|
||||
if (nb::isinstance<mx::array>(outputs)) {
|
||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Outputs can be either a single array "
|
||||
"a tuple or list of arrays.");
|
||||
}
|
||||
return outputs_;
|
||||
};
|
||||
return
|
||||
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
kwargs.update(nb::cast(kwargs_));
|
||||
auto args = nb::tuple(nb::cast(args_));
|
||||
auto outputs = fun(*args, **kwargs);
|
||||
std::vector<mx::array> outputs_;
|
||||
if (nb::isinstance<mx::array>(outputs)) {
|
||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Outputs can be either a single array "
|
||||
"a tuple or list of arrays.");
|
||||
}
|
||||
return outputs_;
|
||||
};
|
||||
}
|
||||
|
||||
void init_export(nb::module_& m) {
|
||||
|
Reference in New Issue
Block a user