mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -555,13 +555,19 @@ struct PyCompiledFun {
|
||||
size_t fun_id;
|
||||
py::object captured_inputs;
|
||||
py::object captured_outputs;
|
||||
bool shapeless;
|
||||
size_t num_outputs{0};
|
||||
|
||||
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
|
||||
PyCompiledFun(
|
||||
const py::function& fun,
|
||||
py::object inputs,
|
||||
py::object outputs,
|
||||
bool shapeless)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
captured_inputs(inputs),
|
||||
captured_outputs(outputs) {}
|
||||
captured_outputs(outputs),
|
||||
shapeless(shapeless) {}
|
||||
|
||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
@@ -571,11 +577,15 @@ struct PyCompiledFun {
|
||||
other.fun_id = 0;
|
||||
captured_inputs = std::move(other.captured_inputs);
|
||||
captured_outputs = std::move(other.captured_outputs);
|
||||
shapeless = other.shapeless;
|
||||
num_outputs = other.num_outputs;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||
auto inputs = tree_flatten(args, false);
|
||||
|
||||
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
|
||||
const std::vector<array>& a) {
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
@@ -586,8 +596,10 @@ struct PyCompiledFun {
|
||||
tree_fill(captured_inputs, trace_captures);
|
||||
}
|
||||
|
||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
||||
std::move(fun(*tree_unflatten(args, a))), false);
|
||||
auto tree_outputs =
|
||||
fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
|
||||
auto [outputs, py_outputs] =
|
||||
tree_flatten_with_structure(std::move(tree_outputs), false);
|
||||
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
|
||||
@@ -607,7 +619,14 @@ struct PyCompiledFun {
|
||||
return outputs;
|
||||
};
|
||||
|
||||
auto inputs = tree_flatten(args, false);
|
||||
{
|
||||
auto flat_kwargs = tree_flatten(kwargs, false);
|
||||
inputs.insert(
|
||||
inputs.end(),
|
||||
std::make_move_iterator(flat_kwargs.begin()),
|
||||
std::make_move_iterator(flat_kwargs.end()));
|
||||
}
|
||||
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
inputs.insert(
|
||||
@@ -616,8 +635,39 @@ struct PyCompiledFun {
|
||||
std::make_move_iterator(flat_in_captures.end()));
|
||||
}
|
||||
|
||||
// Collect the compilation constants
|
||||
std::vector<uint64_t> constants;
|
||||
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
|
||||
// Consider expanding tuples to their contents including start and end
|
||||
// ids
|
||||
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
|
||||
auto r = py::hash(o);
|
||||
return *reinterpret_cast<uint64_t*>(&r);
|
||||
} else if (py::isinstance<py::int_>(o)) {
|
||||
auto r = o.cast<int64_t>();
|
||||
return *reinterpret_cast<uint64_t*>(&r);
|
||||
} else if (py::isinstance<py::float_>(o)) {
|
||||
auto r = o.cast<double>();
|
||||
return *reinterpret_cast<uint64_t*>(&r);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
if (auto h = value_hash(args[i]); h.has_value()) {
|
||||
constants.push_back(*h);
|
||||
}
|
||||
}
|
||||
for (auto& pair : kwargs) {
|
||||
if (auto h = value_hash(pair.second); h.has_value()) {
|
||||
constants.push_back(*value_hash(pair.first));
|
||||
constants.push_back(*h);
|
||||
}
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
auto outputs =
|
||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
std::vector<array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
@@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) {
|
||||
"compile",
|
||||
[](const py::function& fun,
|
||||
const py::object& inputs,
|
||||
const py::object& outputs) {
|
||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
|
||||
const py::object& outputs,
|
||||
bool shapeless) {
|
||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = std::nullopt,
|
||||
"outputs"_a = std::nullopt,
|
||||
"shapeless"_a = false,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
@@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) {
|
||||
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
|
||||
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
|
||||
Default: ``None``
|
||||
shapeless (bool, optional): A function compiled with the ``shapeless``
|
||||
option enabled will not be recompiled when the input shape changes. Not all
|
||||
functions can be compiled with ``shapeless`` enabled. Attempting to compile
|
||||
such functions with shapeless enabled will throw. Note, changing the number
|
||||
of dimensions or type of any input will result in a recompilation even with
|
||||
``shapeless`` set to ``True``. Default: ``False``
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
|
Reference in New Issue
Block a user