mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import * comment
This commit is contained in:
parent
70ebc3b598
commit
dc4eada7f0
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/export.h"
|
#include "mlx/export.h"
|
||||||
|
#include <map>
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@ -481,7 +482,9 @@ bool FunctionTable::match(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& [_, in] : kwargs) {
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
for (auto& [_, in] : sorted_kwargs) {
|
||||||
if (!match_inputs(in, fun.inputs[i++])) {
|
if (!match_inputs(in, fun.inputs[i++])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -557,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
// Flatten the inputs to the function for tracing
|
// Flatten the inputs to the function for tracing
|
||||||
std::vector<std::string> kwarg_keys;
|
std::vector<std::string> kwarg_keys;
|
||||||
auto inputs = args;
|
auto inputs = args;
|
||||||
for (auto& [k, v] : kwargs) {
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
for (auto& [k, v] : sorted_kwargs) {
|
||||||
kwarg_keys.push_back(k);
|
kwarg_keys.push_back(k);
|
||||||
inputs.push_back(v);
|
inputs.push_back(v);
|
||||||
}
|
}
|
||||||
|
@ -2,14 +2,14 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <unordered_map>
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
using Args = std::vector<array>;
|
using Args = std::vector<array>;
|
||||||
using Kwargs = std::map<std::string, array>;
|
using Kwargs = std::unordered_map<std::string, array>;
|
||||||
|
|
||||||
struct FunctionExporter;
|
struct FunctionExporter;
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/map.h>
|
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/unordered_map.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@ -16,8 +16,7 @@ namespace mx = mlx::core;
|
|||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
|
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
|
||||||
validate_and_extract_inputs(
|
|
||||||
const nb::args& args,
|
const nb::args& args,
|
||||||
const nb::kwargs& kwargs,
|
const nb::kwargs& kwargs,
|
||||||
const std::string& prefix) {
|
const std::string& prefix) {
|
||||||
@ -30,8 +29,8 @@ validate_and_extract_inputs(
|
|||||||
"and/or dictionary of arrays.");
|
"and/or dictionary of arrays.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
std::vector<mx::array> args_;
|
mx::Args args_;
|
||||||
std::map<std::string, mx::array> kwargs_;
|
mx::Kwargs kwargs_;
|
||||||
if (args.size() == 0) {
|
if (args.size() == 0) {
|
||||||
// No args so kwargs must be keyword arrays
|
// No args so kwargs must be keyword arrays
|
||||||
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
maybe_throw(nb::try_cast(kwargs, kwargs_));
|
||||||
@ -81,9 +80,7 @@ class PyFunctionExporter {
|
|||||||
void close() {
|
void close() {
|
||||||
exporter_.close();
|
exporter_.close();
|
||||||
}
|
}
|
||||||
void operator()(
|
void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
|
||||||
const std::vector<mx::array>& args,
|
|
||||||
const std::map<std::string, mx::array>& kwargs) {
|
|
||||||
exporter_(args, kwargs);
|
exporter_(args, kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = {
|
|||||||
{0, 0}};
|
{0, 0}};
|
||||||
|
|
||||||
auto wrap_export_function(nb::callable fun) {
|
auto wrap_export_function(nb::callable fun) {
|
||||||
return [fun = std::move(fun)](
|
return
|
||||||
const std::vector<mx::array>& args_,
|
[fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
|
||||||
const std::map<std::string, mx::array>& kwargs_) {
|
auto kwargs = nb::dict();
|
||||||
auto kwargs = nb::dict();
|
kwargs.update(nb::cast(kwargs_));
|
||||||
kwargs.update(nb::cast(kwargs_));
|
auto args = nb::tuple(nb::cast(args_));
|
||||||
auto args = nb::tuple(nb::cast(args_));
|
auto outputs = fun(*args, **kwargs);
|
||||||
auto outputs = fun(*args, **kwargs);
|
std::vector<mx::array> outputs_;
|
||||||
std::vector<mx::array> outputs_;
|
if (nb::isinstance<mx::array>(outputs)) {
|
||||||
if (nb::isinstance<mx::array>(outputs)) {
|
outputs_.push_back(nb::cast<mx::array>(outputs));
|
||||||
outputs_.push_back(nb::cast<mx::array>(outputs));
|
} else if (!nb::try_cast(outputs, outputs_)) {
|
||||||
} else if (!nb::try_cast(outputs, outputs_)) {
|
throw std::invalid_argument(
|
||||||
throw std::invalid_argument(
|
"[export_function] Outputs can be either a single array "
|
||||||
"[export_function] Outputs can be either a single array "
|
"a tuple or list of arrays.");
|
||||||
"a tuple or list of arrays.");
|
}
|
||||||
}
|
return outputs_;
|
||||||
return outputs_;
|
};
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void init_export(nb::module_& m) {
|
void init_export(nb::module_& m) {
|
||||||
|
@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
|
|||||||
TEST_CASE("test export functions with kwargs") {
|
TEST_CASE("test export functions with kwargs") {
|
||||||
std::string file_path = get_temp_file("model.mlxfn");
|
std::string file_path = get_temp_file("model.mlxfn");
|
||||||
|
|
||||||
auto fun =
|
auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
|
||||||
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
|
|
||||||
return {kwargs.at("x") + kwargs.at("y")};
|
return {kwargs.at("x") + kwargs.at("y")};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user