mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import * comment
This commit is contained in:
parent
70ebc3b598
commit
dc4eada7f0
@ -1,5 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/export.h"
|
||||
#include <map>
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@ -481,7 +482,9 @@ bool FunctionTable::match(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (auto& [_, in] : kwargs) {
|
||||
auto sorted_kwargs =
|
||||
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||
for (auto& [_, in] : sorted_kwargs) {
|
||||
if (!match_inputs(in, fun.inputs[i++])) {
|
||||
return false;
|
||||
}
|
||||
@ -557,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
// Flatten the inputs to the function for tracing
|
||||
std::vector<std::string> kwarg_keys;
|
||||
auto inputs = args;
|
||||
for (auto& [k, v] : kwargs) {
|
||||
auto sorted_kwargs =
|
||||
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||
for (auto& [k, v] : sorted_kwargs) {
|
||||
kwarg_keys.push_back(k);
|
||||
inputs.push_back(v);
|
||||
}
|
||||
|
@ -2,14 +2,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using Args = std::vector<array>;
|
||||
using Kwargs = std::map<std::string, array>;
|
||||
using Kwargs = std::unordered_map<std::string, array>;
|
||||
|
||||
struct FunctionExporter;
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/map.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <fstream>
|
||||
@ -16,8 +16,7 @@ namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
|
||||
validate_and_extract_inputs(
|
||||
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs,
|
||||
const std::string& prefix) {
|
||||
@ -30,8 +29,8 @@ validate_and_extract_inputs(
|
||||
"and/or dictionary of arrays.");
|
||||
}
|
||||
};
|
||||
std::vector<mx::array> args_;
|
||||
std::map<std::string, mx::array> kwargs_;
|
||||
mx::Args args_;
|
||||
mx::Kwargs kwargs_;
|
||||
if (args.size() == 0) {
|
||||
// No args so kwargs must be keyword arrays
|
||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||
@ -81,9 +80,7 @@ class PyFunctionExporter {
|
||||
void close() {
|
||||
exporter_.close();
|
||||
}
|
||||
void operator()(
|
||||
const std::vector<mx::array>& args,
|
||||
const std::map<std::string, mx::array>& kwargs) {
|
||||
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
|
||||
exporter_(args, kwargs);
|
||||
}
|
||||
|
||||
@ -112,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
|
||||
{0, 0}};
|
||||
|
||||
auto wrap_export_function(nb::callable fun) {
|
||||
return [fun = std::move(fun)](
|
||||
const std::vector<mx::array>& args_,
|
||||
const std::map<std::string, mx::array>& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
kwargs.update(nb::cast(kwargs_));
|
||||
auto args = nb::tuple(nb::cast(args_));
|
||||
auto outputs = fun(*args, **kwargs);
|
||||
std::vector<mx::array> outputs_;
|
||||
if (nb::isinstance<mx::array>(outputs)) {
|
||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Outputs can be either a single array "
|
||||
"a tuple or list of arrays.");
|
||||
}
|
||||
return outputs_;
|
||||
};
|
||||
return
|
||||
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
kwargs.update(nb::cast(kwargs_));
|
||||
auto args = nb::tuple(nb::cast(args_));
|
||||
auto outputs = fun(*args, **kwargs);
|
||||
std::vector<mx::array> outputs_;
|
||||
if (nb::isinstance<mx::array>(outputs)) {
|
||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Outputs can be either a single array "
|
||||
"a tuple or list of arrays.");
|
||||
}
|
||||
return outputs_;
|
||||
};
|
||||
}
|
||||
|
||||
void init_export(nb::module_& m) {
|
||||
|
@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
|
||||
TEST_CASE("test export functions with kwargs") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun =
|
||||
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
|
||||
auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
|
||||
return {kwargs.at("x") + kwargs.at("y")};
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user