mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Export with callback (#2612)
* export with callback * export with callback * Add types, fix kwarg ordering bug + test * cleanup, test, fix * typos
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
// This file must not include any host-only code, utilies that work under both
|
// This file must not include any host-only code, utilities that work under both
|
||||||
// host and device can be put here.
|
// host and device can be put here.
|
||||||
//
|
//
|
||||||
// See more about the requirements at:
|
// See more about the requirements at:
|
||||||
@@ -202,7 +202,7 @@ struct Limits<
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
// CUDA 11 does not have host side arithmetic operators for half types.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Limits<
|
struct Limits<
|
||||||
T,
|
T,
|
||||||
|
@@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
// This file includes host-only utilies for writing CUDA kernels, the difference
|
// This file includes host-only utilities for writing CUDA kernels, the
|
||||||
// from backend/cuda/device/utils.cuh is that the latter file only include
|
// difference from backend/cuda/device/utils.cuh is that the latter file only
|
||||||
// device-only code.
|
// include device-only code.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
// This file include utilies that are used by C++ code (i.e. .cpp files).
|
// This file include utilities that are used by C++ code (i.e. .cpp files).
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
219
mlx/export.cpp
219
mlx/export.cpp
@@ -3,6 +3,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
#include "mlx/version.h"
|
#include "mlx/version.h"
|
||||||
@@ -13,6 +14,7 @@
|
|||||||
#primitive, { \
|
#primitive, { \
|
||||||
serialize_primitive<primitive>, \
|
serialize_primitive<primitive>, \
|
||||||
deserialize_primitive<primitive>, \
|
deserialize_primitive<primitive>, \
|
||||||
|
primitive_state<primitive>, \
|
||||||
{__VA_ARGS__} \
|
{__VA_ARGS__} \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
@@ -34,15 +36,20 @@ struct PrimitiveSerializer {
|
|||||||
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
||||||
using Deserializer =
|
using Deserializer =
|
||||||
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
||||||
|
using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;
|
||||||
|
|
||||||
PrimitiveSerializer(
|
PrimitiveSerializer(
|
||||||
Serializer serialize,
|
Serializer serialize,
|
||||||
Deserializer deserialize,
|
Deserializer deserialize,
|
||||||
|
StateExtractor extract_state,
|
||||||
std::vector<std::string> keys = {})
|
std::vector<std::string> keys = {})
|
||||||
: serialize(std::move(serialize)),
|
: serialize(std::move(serialize)),
|
||||||
deserialize(std::move(deserialize)),
|
deserialize(std::move(deserialize)),
|
||||||
|
extract_state(std::move(extract_state)),
|
||||||
keys(std::move(keys)) {};
|
keys(std::move(keys)) {};
|
||||||
Serializer serialize;
|
Serializer serialize;
|
||||||
Deserializer deserialize;
|
Deserializer deserialize;
|
||||||
|
StateExtractor extract_state;
|
||||||
std::vector<std::string> keys;
|
std::vector<std::string> keys;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -198,6 +205,32 @@ void serialize_primitive(Writer& os, const Primitive& p) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void extract_state(const T state, std::vector<StateT>& unpacked_state) {
|
||||||
|
if constexpr (std::is_arithmetic_v<T>) {
|
||||||
|
unpacked_state.push_back(state);
|
||||||
|
} else if constexpr (std::is_enum_v<T>) {
|
||||||
|
unpacked_state.push_back(static_cast<int>(state));
|
||||||
|
} else if constexpr (is_iterable<T>) {
|
||||||
|
unpacked_state.push_back(state);
|
||||||
|
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||||
|
std::apply(
|
||||||
|
[&unpacked_state](auto&... x) {
|
||||||
|
(..., extract_state(x, unpacked_state));
|
||||||
|
},
|
||||||
|
state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<StateT> primitive_state(const Primitive& p) {
|
||||||
|
std::vector<StateT> state;
|
||||||
|
if constexpr (has_state<T>) {
|
||||||
|
extract_state(static_cast<const T&>(p).state(), state);
|
||||||
|
}
|
||||||
|
return state;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
|
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
|
||||||
if constexpr (has_state<T>) {
|
if constexpr (has_state<T>) {
|
||||||
@@ -383,6 +416,23 @@ struct PrimitiveFactory {
|
|||||||
"[import_function] Unable to deserialize primitive " + name);
|
"[import_function] Unable to deserialize primitive " + name);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::pair<std::string, std::vector<StateT>> extract_state(
|
||||||
|
const std::shared_ptr<Primitive>& p) {
|
||||||
|
std::string name = p->name();
|
||||||
|
name = name.substr(0, name.find(' '));
|
||||||
|
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
||||||
|
name = it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto it = factory.find(name); it != factory.end()) {
|
||||||
|
auto state = it->second.extract_state(*p);
|
||||||
|
return {name, state};
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[export_function] Unable to get state for primitive " + name);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
void write_header(Writer& os, int count, bool shapeless) {
|
void write_header(Writer& os, int count, bool shapeless) {
|
||||||
@@ -416,8 +466,10 @@ struct FunctionTable {
|
|||||||
};
|
};
|
||||||
bool shapeless;
|
bool shapeless;
|
||||||
std::unordered_map<int, std::vector<Function>> table;
|
std::unordered_map<int, std::vector<Function>> table;
|
||||||
Function* find(const Args& args, const Kwargs& kwargs);
|
Function* find(const Args& args, const std::map<std::string, array>& kwargs);
|
||||||
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs);
|
std::pair<Function&, bool> emplace(
|
||||||
|
const Args& args,
|
||||||
|
const std::map<std::string, array>& kwargs);
|
||||||
void insert(
|
void insert(
|
||||||
std::vector<std::string> kwarg_keys,
|
std::vector<std::string> kwarg_keys,
|
||||||
std::vector<array> inputs,
|
std::vector<array> inputs,
|
||||||
@@ -453,12 +505,15 @@ struct FunctionTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool match(const Args& args, const Kwargs& kwargs, const Function& fun);
|
bool match(
|
||||||
|
const Args& args,
|
||||||
|
const std::map<std::string, array>& kwargs,
|
||||||
|
const Function& fun);
|
||||||
};
|
};
|
||||||
|
|
||||||
bool FunctionTable::match(
|
bool FunctionTable::match(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs,
|
const std::map<std::string, array>& kwargs,
|
||||||
const Function& fun) {
|
const Function& fun) {
|
||||||
for (auto& k : fun.kwarg_keys) {
|
for (auto& k : fun.kwarg_keys) {
|
||||||
if (kwargs.find(k) == kwargs.end()) {
|
if (kwargs.find(k) == kwargs.end()) {
|
||||||
@@ -486,9 +541,7 @@ bool FunctionTable::match(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto sorted_kwargs =
|
for (auto& [_, in] : kwargs) {
|
||||||
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
|
||||||
for (auto& [_, in] : sorted_kwargs) {
|
|
||||||
if (!match_inputs(in, fun.inputs[i++])) {
|
if (!match_inputs(in, fun.inputs[i++])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -499,7 +552,7 @@ bool FunctionTable::match(
|
|||||||
|
|
||||||
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs) {
|
const std::map<std::string, array>& kwargs) {
|
||||||
auto n_inputs = args.size() + kwargs.size();
|
auto n_inputs = args.size() + kwargs.size();
|
||||||
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
|
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
|
||||||
auto& funs_vec = it->second;
|
auto& funs_vec = it->second;
|
||||||
@@ -516,7 +569,7 @@ std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
|
|||||||
|
|
||||||
FunctionTable::Function* FunctionTable::find(
|
FunctionTable::Function* FunctionTable::find(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs) {
|
const std::map<std::string, array>& kwargs) {
|
||||||
auto n_inputs = args.size() + kwargs.size();
|
auto n_inputs = args.size() + kwargs.size();
|
||||||
auto it = table.find(n_inputs);
|
auto it = table.find(n_inputs);
|
||||||
if (it == table.end()) {
|
if (it == table.end()) {
|
||||||
@@ -545,16 +598,86 @@ FunctionExporter::FunctionExporter(
|
|||||||
write_header(os, count, shapeless);
|
write_header(os, count, shapeless);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionExporter::FunctionExporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
|
bool shapeless)
|
||||||
|
: callback(callback),
|
||||||
|
fun(std::move(fun)),
|
||||||
|
ftable(std::make_shared<FunctionTable>(shapeless)) {}
|
||||||
|
|
||||||
void FunctionExporter::close() {
|
void FunctionExporter::close() {
|
||||||
closed = true;
|
closed = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void FunctionExporter::export_with_callback(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<std::string>& kwarg_keys) {
|
||||||
|
NodeNamer namer{};
|
||||||
|
auto to_vector_data = [&namer](const auto& arrays) {
|
||||||
|
std::vector<std::tuple<std::string, Shape, Dtype>> data;
|
||||||
|
for (auto& a : arrays) {
|
||||||
|
data.emplace_back(namer.get_name(a), a.shape(), a.dtype());
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Callback on the inputs
|
||||||
|
callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}});
|
||||||
|
std::vector<std::pair<std::string, std::string>> keyword_inputs;
|
||||||
|
for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size();
|
||||||
|
++i, ++j) {
|
||||||
|
keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i]));
|
||||||
|
}
|
||||||
|
callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}});
|
||||||
|
|
||||||
|
// Callback on the outputs
|
||||||
|
callback({{"type", "outputs"}, {"outputs", to_vector_data(outputs)}});
|
||||||
|
|
||||||
|
// Callback on the constants
|
||||||
|
{
|
||||||
|
std::unordered_set<std::uintptr_t> input_set;
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
input_set.insert(in.id());
|
||||||
|
}
|
||||||
|
std::vector<std::pair<std::string, array>> new_constants;
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (constants.insert(arr.id()).second) {
|
||||||
|
new_constants.emplace_back(namer.get_name(arr), arr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
callback({{"type", "constants"}, {"constants", new_constants}});
|
||||||
|
}
|
||||||
|
auto factory = PrimitiveFactory();
|
||||||
|
|
||||||
|
// Callback for each primitive in the tape
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
if (!arr.has_primitive()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto [name, state] = factory.extract_state(arr.primitive_ptr());
|
||||||
|
callback(
|
||||||
|
{{"type", "primitive"},
|
||||||
|
{"inputs", to_vector_data(arr.inputs())},
|
||||||
|
{"outputs", to_vector_data(arr.outputs())},
|
||||||
|
{"name", name},
|
||||||
|
{"arguments", state}});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||||
if (closed) {
|
if (closed) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[export_function] Attempting to write after exporting is closed.");
|
"[export_function] Attempting to write after exporting is closed.");
|
||||||
}
|
}
|
||||||
auto [fentry, inserted] = ftable->emplace(args, kwargs);
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs);
|
||||||
if (!inserted) {
|
if (!inserted) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[export_function] Attempting to export a function twice with "
|
"[export_function] Attempting to export a function twice with "
|
||||||
@@ -564,8 +687,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
// Flatten the inputs to the function for tracing
|
// Flatten the inputs to the function for tracing
|
||||||
std::vector<std::string> kwarg_keys;
|
std::vector<std::string> kwarg_keys;
|
||||||
auto inputs = args;
|
auto inputs = args;
|
||||||
auto sorted_kwargs =
|
|
||||||
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
|
||||||
for (auto& [k, v] : sorted_kwargs) {
|
for (auto& [k, v] : sorted_kwargs) {
|
||||||
kwarg_keys.push_back(k);
|
kwarg_keys.push_back(k);
|
||||||
inputs.push_back(v);
|
inputs.push_back(v);
|
||||||
@@ -592,10 +713,18 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
|
|
||||||
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
||||||
|
|
||||||
// Update header
|
// Update the table entry
|
||||||
|
fentry.kwarg_keys = kwarg_keys;
|
||||||
|
fentry.inputs = trace_inputs;
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
|
|
||||||
// Overwrite the header
|
if (callback) {
|
||||||
|
export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the header
|
||||||
auto pos = os.tell();
|
auto pos = os.tell();
|
||||||
os.seek(0);
|
os.seek(0);
|
||||||
write_header(os, count, ftable->shapeless);
|
write_header(os, count, ftable->shapeless);
|
||||||
@@ -616,10 +745,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
serialize(os, trace_inputs);
|
serialize(os, trace_inputs);
|
||||||
serialize(os, arrays_to_ids(trace_outputs));
|
serialize(os, arrays_to_ids(trace_outputs));
|
||||||
|
|
||||||
// Update the table entry
|
|
||||||
fentry.kwarg_keys = std::move(kwarg_keys);
|
|
||||||
fentry.inputs = std::move(trace_inputs);
|
|
||||||
|
|
||||||
std::unordered_set<std::uintptr_t> input_set(
|
std::unordered_set<std::uintptr_t> input_set(
|
||||||
trace_input_ids.begin(), trace_input_ids.end());
|
trace_input_ids.begin(), trace_input_ids.end());
|
||||||
|
|
||||||
@@ -730,6 +855,58 @@ void export_function(
|
|||||||
exporter(file, fun, shapeless)(args, kwargs);
|
exporter(file, fun, shapeless)(args, kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
return FunctionExporter{
|
||||||
|
callback,
|
||||||
|
[fun](const Args& args, const Kwargs&) { return fun(args); },
|
||||||
|
shapeless};
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
return exporter(
|
||||||
|
callback,
|
||||||
|
[fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },
|
||||||
|
shapeless);
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
return FunctionExporter{callback, fun, shapeless};
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(callback, fun, shapeless)(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(callback, fun, shapeless)(kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(callback, fun, shapeless)(args, kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
|
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
|
||||||
return this->operator()({}, kwargs);
|
return this->operator()({}, kwargs);
|
||||||
}
|
}
|
||||||
@@ -741,7 +918,9 @@ std::vector<array> ImportedFunction::operator()(const Args& args) const {
|
|||||||
std::vector<array> ImportedFunction::operator()(
|
std::vector<array> ImportedFunction::operator()(
|
||||||
const Args& args,
|
const Args& args,
|
||||||
const Kwargs& kwargs) const {
|
const Kwargs& kwargs) const {
|
||||||
auto* fun = ftable->find(args, kwargs);
|
auto sorted_kwargs =
|
||||||
|
std::map<std::string, array>(kwargs.begin(), kwargs.end());
|
||||||
|
auto* fun = ftable->find(args, sorted_kwargs);
|
||||||
if (fun == nullptr) {
|
if (fun == nullptr) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[import_function::call] No imported function found which matches "
|
msg << "[import_function::call] No imported function found which matches "
|
||||||
@@ -760,7 +939,7 @@ std::vector<array> ImportedFunction::operator()(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto inputs = args;
|
auto inputs = args;
|
||||||
for (auto& [_, v] : kwargs) {
|
for (auto& [_, v] : sorted_kwargs) {
|
||||||
inputs.push_back(v);
|
inputs.push_back(v);
|
||||||
}
|
}
|
||||||
return detail::compile_replace(
|
return detail::compile_replace(
|
||||||
|
66
mlx/export.h
66
mlx/export.h
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <variant>
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -11,6 +12,30 @@ namespace mlx::core {
|
|||||||
using Args = std::vector<array>;
|
using Args = std::vector<array>;
|
||||||
using Kwargs = std::unordered_map<std::string, array>;
|
using Kwargs = std::unordered_map<std::string, array>;
|
||||||
|
|
||||||
|
// Possible types for a Primitive's state
|
||||||
|
using StateT = std::variant<
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
size_t,
|
||||||
|
float,
|
||||||
|
double,
|
||||||
|
Dtype,
|
||||||
|
Shape,
|
||||||
|
Strides,
|
||||||
|
std::vector<int>,
|
||||||
|
std::vector<size_t>,
|
||||||
|
std::string>;
|
||||||
|
|
||||||
|
using ExportCallbackInput = std::unordered_map<
|
||||||
|
std::string,
|
||||||
|
std::variant<
|
||||||
|
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
||||||
|
std::vector<std::pair<std::string, array>>,
|
||||||
|
std::vector<std::pair<std::string, std::string>>,
|
||||||
|
std::vector<StateT>,
|
||||||
|
std::string>>;
|
||||||
|
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
||||||
|
|
||||||
struct FunctionExporter;
|
struct FunctionExporter;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -61,6 +86,47 @@ struct ImportedFunction;
|
|||||||
*/
|
*/
|
||||||
ImportedFunction import_function(const std::string& file);
|
ImportedFunction import_function(const std::string& file);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make an exporter to export multiple traces of a given function with the same
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Export a function with a callback.
|
||||||
|
*/
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|
||||||
#include "mlx/export_impl.h"
|
#include "mlx/export_impl.h"
|
||||||
|
@@ -38,13 +38,40 @@ struct FunctionExporter {
|
|||||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
||||||
bool shapeless);
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const ExportCallback&,
|
||||||
|
const std::function<std::vector<array>(const Args&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const ExportCallback&,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const ExportCallback&,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
FunctionExporter(
|
FunctionExporter(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
bool shapeless);
|
bool shapeless);
|
||||||
|
|
||||||
|
FunctionExporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
io::FileWriter os;
|
io::FileWriter os;
|
||||||
|
ExportCallback callback;
|
||||||
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
||||||
void export_function(const Args& args, const Kwargs& kwargs);
|
void export_function(const Args& args, const Kwargs& kwargs);
|
||||||
|
void export_with_callback(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& tape,
|
||||||
|
const std::vector<std::string>& kwarg_keys);
|
||||||
std::set<std::uintptr_t> constants;
|
std::set<std::uintptr_t> constants;
|
||||||
int count{0};
|
int count{0};
|
||||||
bool closed{false};
|
bool closed{false};
|
||||||
|
@@ -108,6 +108,7 @@ class ParallelFileReader : public Reader {
|
|||||||
|
|
||||||
class FileWriter : public Writer {
|
class FileWriter : public Writer {
|
||||||
public:
|
public:
|
||||||
|
explicit FileWriter() {}
|
||||||
explicit FileWriter(std::string file_path)
|
explicit FileWriter(std::string file_path)
|
||||||
: fd_(open(
|
: fd_(open(
|
||||||
file_path.c_str(),
|
file_path.c_str(),
|
||||||
|
@@ -1,8 +1,11 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
|
#include <nanobind/stl/pair.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/tuple.h>
|
||||||
#include <nanobind/stl/unordered_map.h>
|
#include <nanobind/stl/unordered_map.h>
|
||||||
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@@ -131,24 +134,38 @@ auto wrap_export_function(nb::callable fun) {
|
|||||||
void init_export(nb::module_& m) {
|
void init_export(nb::module_& m) {
|
||||||
m.def(
|
m.def(
|
||||||
"export_function",
|
"export_function",
|
||||||
[](const std::string& file,
|
[](nb::object& file_or_callback,
|
||||||
const nb::callable& fun,
|
const nb::callable& fun,
|
||||||
const nb::args& args,
|
const nb::args& args,
|
||||||
bool shapeless,
|
bool shapeless,
|
||||||
const nb::kwargs& kwargs) {
|
const nb::kwargs& kwargs) {
|
||||||
auto [args_, kwargs_] =
|
auto [args_, kwargs_] =
|
||||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||||
mx::export_function(
|
if (nb::isinstance<nb::str>(file_or_callback)) {
|
||||||
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
mx::export_function(
|
||||||
|
nb::cast<std::string>(file_or_callback),
|
||||||
|
wrap_export_function(fun),
|
||||||
|
args_,
|
||||||
|
kwargs_,
|
||||||
|
shapeless);
|
||||||
|
} else {
|
||||||
|
auto callback = nb::cast<nb::callable>(file_or_callback);
|
||||||
|
auto wrapped_callback =
|
||||||
|
[callback](const mx::ExportCallbackInput& input) {
|
||||||
|
return callback(input);
|
||||||
|
};
|
||||||
|
mx::export_function(
|
||||||
|
callback, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"file"_a,
|
nb::arg(),
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"args"_a,
|
"args"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"shapeless"_a = false,
|
"shapeless"_a = false,
|
||||||
"kwargs"_a,
|
"kwargs"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Export a function to a file.
|
Export an MLX function.
|
||||||
|
|
||||||
Example input arrays must be provided to export a function. The example
|
Example input arrays must be provided to export a function. The example
|
||||||
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||||
@@ -161,7 +178,8 @@ void init_export(nb::module_& m) {
|
|||||||
versions of MLX may not be compatible with future versions.
|
versions of MLX may not be compatible with future versions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str): File path to export the function to.
|
file (str or Callable): Either a file path to export the function
|
||||||
|
to or a callback.
|
||||||
fun (Callable): A function which takes as input zero or more
|
fun (Callable): A function which takes as input zero or more
|
||||||
:class:`array` and returns one or more :class:`array`.
|
:class:`array` and returns one or more :class:`array`.
|
||||||
*args (array): Example array inputs to the function.
|
*args (array): Example array inputs to the function.
|
||||||
|
@@ -319,7 +319,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
# Check the state is unchanged
|
# Check the state is unchanged
|
||||||
self.assertEqual(state["y"], 2)
|
self.assertEqual(state["y"], 2)
|
||||||
|
|
||||||
# Check the udpated state is used
|
# Check the updated state is used
|
||||||
state["y"] = mx.array(3)
|
state["y"] = mx.array(3)
|
||||||
out = test_state(mx.array(1))
|
out = test_state(mx.array(1))
|
||||||
self.assertEqual(out.item(), 4)
|
self.assertEqual(out.item(), 4)
|
||||||
|
@@ -485,6 +485,52 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_export_kwarg_ordering(self):
|
||||||
|
path = os.path.join(self.test_dir, "fun.mlxfn")
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return x - y
|
||||||
|
|
||||||
|
mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
|
||||||
|
imported = mx.import_function(path)
|
||||||
|
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
|
||||||
|
self.assertEqual(out.item(), -1.0)
|
||||||
|
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
|
||||||
|
self.assertEqual(out.item(), 1.0)
|
||||||
|
|
||||||
|
def test_export_with_callback(self):
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return mx.log(mx.abs(x - y))
|
||||||
|
|
||||||
|
n_in = None
|
||||||
|
n_out = None
|
||||||
|
n_const = None
|
||||||
|
keywords = None
|
||||||
|
primitives = []
|
||||||
|
|
||||||
|
def callback(args):
|
||||||
|
nonlocal n_in, n_out, n_const, keywords, primitives
|
||||||
|
t = args["type"]
|
||||||
|
if t == "inputs":
|
||||||
|
n_in = len(args["inputs"])
|
||||||
|
elif args["type"] == "outputs":
|
||||||
|
n_out = len(args["outputs"])
|
||||||
|
elif args["type"] == "keyword_inputs":
|
||||||
|
keywords = args["keywords"]
|
||||||
|
elif t == "constants":
|
||||||
|
n_const = len(args["constants"])
|
||||||
|
elif t == "primitive":
|
||||||
|
primitives.append(args["name"])
|
||||||
|
|
||||||
|
mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0))
|
||||||
|
self.assertEqual(n_in, 2)
|
||||||
|
self.assertEqual(n_out, 1)
|
||||||
|
self.assertEqual(n_const, 0)
|
||||||
|
self.assertEqual(len(keywords), 1)
|
||||||
|
self.assertEqual(keywords[0][0], "y")
|
||||||
|
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Reference in New Issue
Block a user