mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add types, fix kwarg ordering bug + test
This commit is contained in:
@@ -467,8 +467,10 @@ struct FunctionTable {
|
|||||||
};
|
};
|
||||||
bool shapeless;
|
bool shapeless;
|
||||||
std::unordered_map<int, std::vector<Function>> table;
|
std::unordered_map<int, std::vector<Function>> table;
|
||||||
Function* find(const Args& args, const Kwargs& kwargs);
|
Function* find(const Args& args, const std::map<std::string, array>& kwargs);
|
||||||
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs);
|
std::pair<Function&, bool> emplace(
|
||||||
|
const Args& args,
|
||||||
|
const std::map<std::string, array>& kwargs);
|
||||||
void insert(
|
void insert(
|
||||||
std::vector<std::string> kwarg_keys,
|
std::vector<std::string> kwarg_keys,
|
||||||
std::vector<array> inputs,
|
std::vector<array> inputs,
|
||||||
@@ -504,12 +506,15 @@ struct FunctionTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool match(const Args& args, const Kwargs& kwargs, const Function& fun);
|
bool match(
|
||||||
|
const Args& args,
|
||||||
|
const std::map<std::string, array>& kwargs,
|
||||||
|
const Function& fun);
|
||||||
};
|
};
|
||||||
|
|
||||||
bool FunctionTable::match(
|
bool FunctionTable::match(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs,
|
const std::map<std::string, array>& kwargs,
|
||||||
const Function& fun) {
|
const Function& fun) {
|
||||||
for (auto& k : fun.kwarg_keys) {
|
for (auto& k : fun.kwarg_keys) {
|
||||||
if (kwargs.find(k) == kwargs.end()) {
|
if (kwargs.find(k) == kwargs.end()) {
|
||||||
@@ -537,9 +542,7 @@ bool FunctionTable::match(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto sorted_kwargs =
|
for (auto& [_, in] : kwargs) {
|
||||||
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
|
||||||
for (auto& [_, in] : sorted_kwargs) {
|
|
||||||
if (!match_inputs(in, fun.inputs[i++])) {
|
if (!match_inputs(in, fun.inputs[i++])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -550,7 +553,7 @@ bool FunctionTable::match(
|
|||||||
|
|
||||||
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs) {
|
const std::map<std::string, array>& kwargs) {
|
||||||
auto n_inputs = args.size() + kwargs.size();
|
auto n_inputs = args.size() + kwargs.size();
|
||||||
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
|
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
|
||||||
auto& funs_vec = it->second;
|
auto& funs_vec = it->second;
|
||||||
@@ -567,7 +570,7 @@ std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
|||||||
|
|
||||||
FunctionTable::Function* FunctionTable::find(
|
FunctionTable::Function* FunctionTable::find(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs) {
|
const std::map<std::string, array>& kwargs) {
|
||||||
auto n_inputs = args.size() + kwargs.size();
|
auto n_inputs = args.size() + kwargs.size();
|
||||||
auto it = table.find(n_inputs);
|
auto it = table.find(n_inputs);
|
||||||
if (it == table.end()) {
|
if (it == table.end()) {
|
||||||
@@ -611,7 +614,8 @@ void FunctionExporter::close() {
|
|||||||
void FunctionExporter::export_with_callback(
|
void FunctionExporter::export_with_callback(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& tape) {
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<std::string>& kwarg_keys) {
|
||||||
NodeNamer namer{};
|
NodeNamer namer{};
|
||||||
auto to_vector_data = [&namer](const auto& arrays) {
|
auto to_vector_data = [&namer](const auto& arrays) {
|
||||||
std::vector<std::tuple<std::string, Shape, Dtype>> data;
|
std::vector<std::tuple<std::string, Shape, Dtype>> data;
|
||||||
@@ -622,10 +626,15 @@ void FunctionExporter::export_with_callback(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Callback on the inputs
|
// Callback on the inputs
|
||||||
callback({{"inputs", to_vector_data(inputs)}});
|
callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}});
|
||||||
|
std::vector<std::pair<std::string, std::string>> keyword_inputs;
|
||||||
|
for (int i = inputs.size() - kwarg_keys.size(); i < inputs.size(); ++i) {
|
||||||
|
keyword_inputs.emplace_back(kwarg_keys[i], namer.get_name(inputs[i]));
|
||||||
|
}
|
||||||
|
callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}});
|
||||||
|
|
||||||
// Callback on the outputs
|
// Callback on the outputs
|
||||||
callback({{"outputs", to_vector_data(outputs)}});
|
callback({{"type", "outputs"}, {"outputs", to_vector_data(outputs)}});
|
||||||
|
|
||||||
// Callback on the constants
|
// Callback on the constants
|
||||||
{
|
{
|
||||||
@@ -642,7 +651,7 @@ void FunctionExporter::export_with_callback(
|
|||||||
new_constants.emplace_back(namer.get_name(arr), arr);
|
new_constants.emplace_back(namer.get_name(arr), arr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
callback({{"constants", new_constants}});
|
callback({{"type", "constants"}, {"constants", new_constants}});
|
||||||
}
|
}
|
||||||
auto factory = PrimitiveFactory();
|
auto factory = PrimitiveFactory();
|
||||||
|
|
||||||
@@ -653,10 +662,11 @@ void FunctionExporter::export_with_callback(
|
|||||||
}
|
}
|
||||||
auto [name, state] = factory.extract_state(arr.primitive_ptr());
|
auto [name, state] = factory.extract_state(arr.primitive_ptr());
|
||||||
callback(
|
callback(
|
||||||
{{"inputs", to_vector_data(arr.inputs())},
|
{{"type", "primitive"},
|
||||||
|
{"inputs", to_vector_data(arr.inputs())},
|
||||||
{"outputs", to_vector_data(arr.outputs())},
|
{"outputs", to_vector_data(arr.outputs())},
|
||||||
{"primitive", name},
|
{"name", name},
|
||||||
{"state", state}});
|
{"arguments", state}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,7 +675,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[export_function] Attempting to write after exporting is closed.");
|
"[export_function] Attempting to write after exporting is closed.");
|
||||||
}
|
}
|
||||||
auto [fentry, inserted] = ftable->emplace(args, kwargs);
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs);
|
||||||
if (!inserted) {
|
if (!inserted) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[export_function] Attempting to export a function twice with "
|
"[export_function] Attempting to export a function twice with "
|
||||||
@@ -675,8 +687,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
// Flatten the inputs to the function for tracing
|
// Flatten the inputs to the function for tracing
|
||||||
std::vector<std::string> kwarg_keys;
|
std::vector<std::string> kwarg_keys;
|
||||||
auto inputs = args;
|
auto inputs = args;
|
||||||
auto sorted_kwargs =
|
|
||||||
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
|
||||||
for (auto& [k, v] : sorted_kwargs) {
|
for (auto& [k, v] : sorted_kwargs) {
|
||||||
kwarg_keys.push_back(k);
|
kwarg_keys.push_back(k);
|
||||||
inputs.push_back(v);
|
inputs.push_back(v);
|
||||||
@@ -710,7 +720,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
count++;
|
count++;
|
||||||
|
|
||||||
if (callback) {
|
if (callback) {
|
||||||
export_with_callback(trace_inputs, trace_outputs, tape);
|
export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -908,7 +918,9 @@ std::vector<array> ImportedFunction::operator()(const Args& args) const {
|
|||||||
std::vector<array> ImportedFunction::operator()(
|
std::vector<array> ImportedFunction::operator()(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs) const {
|
const Kwargs& kwargs) const {
|
||||||
auto* fun = ftable->find(args, kwargs);
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
auto* fun = ftable->find(args, sorted_kwargs);
|
||||||
if (fun == nullptr) {
|
if (fun == nullptr) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[import_function::call] No imported function found which matches "
|
msg << "[import_function::call] No imported function found which matches "
|
||||||
@@ -927,7 +939,7 @@ std::vector<array> ImportedFunction::operator()(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto inputs = args;
|
auto inputs = args;
|
||||||
for (auto& [_, v] : kwargs) {
|
for (auto& [_, v] : sorted_kwargs) {
|
||||||
inputs.push_back(v);
|
inputs.push_back(v);
|
||||||
}
|
}
|
||||||
return detail::compile_replace(
|
return detail::compile_replace(
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ using ExportCallbackInput = std::unordered_map<
|
|||||||
std::variant<
|
std::variant<
|
||||||
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
||||||
std::vector<std::pair<std::string, array>>,
|
std::vector<std::pair<std::string, array>>,
|
||||||
|
std::vector<std::pair<std::string, std::string>>,
|
||||||
std::vector<StateT>,
|
std::vector<StateT>,
|
||||||
std::string>>;
|
std::string>>;
|
||||||
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
||||||
|
|||||||
@@ -70,7 +70,8 @@ struct FunctionExporter {
|
|||||||
void export_with_callback(
|
void export_with_callback(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& tape);
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<std::string>& kwarg_keys);
|
||||||
std::set<std::uintptr_t> constants;
|
std::set<std::uintptr_t> constants;
|
||||||
int count{0};
|
int count{0};
|
||||||
bool closed{false};
|
bool closed{false};
|
||||||
|
|||||||
@@ -485,6 +485,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_export_kwarg_ordering(self):
|
||||||
|
path = os.path.join(self.test_dir, "fun.mlxfn")
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return x - y
|
||||||
|
|
||||||
|
mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
|
||||||
|
self.assertEqual(out.item(), -1.0)
|
||||||
|
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
|
||||||
|
self.assertEqual(out.item(), 1.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user