Add types, fix kwarg ordering bug + test

This commit is contained in:
Awni Hannun
2025-09-24 15:50:53 -07:00
parent 9fcfcf04c6
commit ad0dd9b5ba
4 changed files with 50 additions and 23 deletions

View File

@@ -467,8 +467,10 @@ struct FunctionTable {
}; };
bool shapeless; bool shapeless;
std::unordered_map<int, std::vector<Function>> table; std::unordered_map<int, std::vector<Function>> table;
Function* find(const Args& args, const Kwargs& kwargs); Function* find(const Args& args, const std::map<std::string, array>& kwargs);
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs); std::pair<Function&, bool> emplace(
const Args& args,
const std::map<std::string, array>& kwargs);
void insert( void insert(
std::vector<std::string> kwarg_keys, std::vector<std::string> kwarg_keys,
std::vector<array> inputs, std::vector<array> inputs,
@@ -504,12 +506,15 @@ struct FunctionTable {
} }
private: private:
bool match(const Args& args, const Kwargs& kwargs, const Function& fun); bool match(
const Args& args,
const std::map<std::string, array>& kwargs,
const Function& fun);
}; };
bool FunctionTable::match( bool FunctionTable::match(
const Args& args, const Args& args,
const Kwargs& kwargs, const std::map<std::string, array>& kwargs,
const Function& fun) { const Function& fun) {
for (auto& k : fun.kwarg_keys) { for (auto& k : fun.kwarg_keys) {
if (kwargs.find(k) == kwargs.end()) { if (kwargs.find(k) == kwargs.end()) {
@@ -537,9 +542,7 @@ bool FunctionTable::match(
return false; return false;
} }
} }
auto sorted_kwargs = for (auto& [_, in] : kwargs) {
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [_, in] : sorted_kwargs) {
if (!match_inputs(in, fun.inputs[i++])) { if (!match_inputs(in, fun.inputs[i++])) {
return false; return false;
} }
@@ -550,7 +553,7 @@ bool FunctionTable::match(
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace( std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
const Args& args, const Args& args,
const Kwargs& kwargs) { const std::map<std::string, array>& kwargs) {
auto n_inputs = args.size() + kwargs.size(); auto n_inputs = args.size() + kwargs.size();
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{}); auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
auto& funs_vec = it->second; auto& funs_vec = it->second;
@@ -567,7 +570,7 @@ std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
FunctionTable::Function* FunctionTable::find( FunctionTable::Function* FunctionTable::find(
const Args& args, const Args& args,
const Kwargs& kwargs) { const std::map<std::string, array>& kwargs) {
auto n_inputs = args.size() + kwargs.size(); auto n_inputs = args.size() + kwargs.size();
auto it = table.find(n_inputs); auto it = table.find(n_inputs);
if (it == table.end()) { if (it == table.end()) {
@@ -611,7 +614,8 @@ void FunctionExporter::close() {
void FunctionExporter::export_with_callback( void FunctionExporter::export_with_callback(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs, const std::vector<array>& outputs,
const std::vector<array>& tape) { const std::vector<array>& tape,
const std::vector<std::string>& kwarg_keys) {
NodeNamer namer{}; NodeNamer namer{};
auto to_vector_data = [&namer](const auto& arrays) { auto to_vector_data = [&namer](const auto& arrays) {
std::vector<std::tuple<std::string, Shape, Dtype>> data; std::vector<std::tuple<std::string, Shape, Dtype>> data;
@@ -622,10 +626,15 @@ void FunctionExporter::export_with_callback(
}; };
// Callback on the inputs // Callback on the inputs
callback({{"inputs", to_vector_data(inputs)}}); callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}});
std::vector<std::pair<std::string, std::string>> keyword_inputs;
for (int i = inputs.size() - kwarg_keys.size(); i < inputs.size(); ++i) {
keyword_inputs.emplace_back(kwarg_keys[i], namer.get_name(inputs[i]));
}
callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}});
// Callback on the outputs // Callback on the outputs
callback({{"outputs", to_vector_data(outputs)}}); callback({{"type", "outputs"}, {"outputs", to_vector_data(outputs)}});
// Callback on the constants // Callback on the constants
{ {
@@ -642,7 +651,7 @@ void FunctionExporter::export_with_callback(
new_constants.emplace_back(namer.get_name(arr), arr); new_constants.emplace_back(namer.get_name(arr), arr);
} }
} }
callback({{"constants", new_constants}}); callback({{"type", "constants"}, {"constants", new_constants}});
} }
auto factory = PrimitiveFactory(); auto factory = PrimitiveFactory();
@@ -653,10 +662,11 @@ void FunctionExporter::export_with_callback(
} }
auto [name, state] = factory.extract_state(arr.primitive_ptr()); auto [name, state] = factory.extract_state(arr.primitive_ptr());
callback( callback(
{{"inputs", to_vector_data(arr.inputs())}, {{"type", "primitive"},
{"inputs", to_vector_data(arr.inputs())},
{"outputs", to_vector_data(arr.outputs())}, {"outputs", to_vector_data(arr.outputs())},
{"primitive", name}, {"name", name},
{"state", state}}); {"arguments", state}});
} }
} }
@@ -665,7 +675,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
throw std::runtime_error( throw std::runtime_error(
"[export_function] Attempting to write after exporting is closed."); "[export_function] Attempting to write after exporting is closed.");
} }
auto [fentry, inserted] = ftable->emplace(args, kwargs); auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs);
if (!inserted) { if (!inserted) {
throw std::runtime_error( throw std::runtime_error(
"[export_function] Attempting to export a function twice with " "[export_function] Attempting to export a function twice with "
@@ -675,8 +687,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
// Flatten the inputs to the function for tracing // Flatten the inputs to the function for tracing
std::vector<std::string> kwarg_keys; std::vector<std::string> kwarg_keys;
auto inputs = args; auto inputs = args;
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [k, v] : sorted_kwargs) { for (auto& [k, v] : sorted_kwargs) {
kwarg_keys.push_back(k); kwarg_keys.push_back(k);
inputs.push_back(v); inputs.push_back(v);
@@ -710,7 +720,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
count++; count++;
if (callback) { if (callback) {
export_with_callback(trace_inputs, trace_outputs, tape); export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys);
return; return;
} }
@@ -908,7 +918,9 @@ std::vector<array> ImportedFunction::operator()(const Args& args) const {
std::vector<array> ImportedFunction::operator()( std::vector<array> ImportedFunction::operator()(
const Args& args, const Args& args,
const Kwargs& kwargs) const { const Kwargs& kwargs) const {
auto* fun = ftable->find(args, kwargs); auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
auto* fun = ftable->find(args, sorted_kwargs);
if (fun == nullptr) { if (fun == nullptr) {
std::ostringstream msg; std::ostringstream msg;
msg << "[import_function::call] No imported function found which matches " msg << "[import_function::call] No imported function found which matches "
@@ -927,7 +939,7 @@ std::vector<array> ImportedFunction::operator()(
} }
auto inputs = args; auto inputs = args;
for (auto& [_, v] : kwargs) { for (auto& [_, v] : sorted_kwargs) {
inputs.push_back(v); inputs.push_back(v);
} }
return detail::compile_replace( return detail::compile_replace(

View File

@@ -31,6 +31,7 @@ using ExportCallbackInput = std::unordered_map<
std::variant< std::variant<
std::vector<std::tuple<std::string, Shape, Dtype>>, std::vector<std::tuple<std::string, Shape, Dtype>>,
std::vector<std::pair<std::string, array>>, std::vector<std::pair<std::string, array>>,
std::vector<std::pair<std::string, std::string>>,
std::vector<StateT>, std::vector<StateT>,
std::string>>; std::string>>;
using ExportCallback = std::function<void(const ExportCallbackInput&)>; using ExportCallback = std::function<void(const ExportCallbackInput&)>;

View File

@@ -70,7 +70,8 @@ struct FunctionExporter {
void export_with_callback( void export_with_callback(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs, const std::vector<array>& outputs,
const std::vector<array>& tape); const std::vector<array>& tape,
const std::vector<std::string>& kwarg_keys);
std::set<std::uintptr_t> constants; std::set<std::uintptr_t> constants;
int count{0}; int count{0};
bool closed{false}; bool closed{false};

View File

@@ -485,6 +485,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
mx.array_equal(imported_fn(input_data)[0], model(input_data)) mx.array_equal(imported_fn(input_data)[0], model(input_data))
) )
def test_export_kwarg_ordering(self):
path = os.path.join(self.test_dir, "fun.mlxfn")
def fn(x, y):
return x - y
mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
imported = mx.import_function(path)
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
self.assertEqual(out.item(), -1.0)
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
self.assertEqual(out.item(), 1.0)
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()