Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -555,13 +555,19 @@ struct PyCompiledFun {
size_t fun_id;
py::object captured_inputs;
py::object captured_outputs;
bool shapeless;
size_t num_outputs{0};
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
PyCompiledFun(
const py::function& fun,
py::object inputs,
py::object outputs,
bool shapeless)
: fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
captured_inputs(inputs),
captured_outputs(outputs) {}
captured_outputs(outputs),
shapeless(shapeless) {}
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@@ -571,11 +577,15 @@ struct PyCompiledFun {
other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);
shapeless = other.shapeless;
num_outputs = other.num_outputs;
};
py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto inputs = tree_flatten(args, false);
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
const std::vector<array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
@@ -586,8 +596,10 @@ struct PyCompiledFun {
tree_fill(captured_inputs, trace_captures);
}
auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(fun(*tree_unflatten(args, a))), false);
auto tree_outputs =
fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
auto [outputs, py_outputs] =
tree_flatten_with_structure(std::move(tree_outputs), false);
tree_cache().insert({fun_id, py_outputs});
@@ -607,7 +619,14 @@ struct PyCompiledFun {
return outputs;
};
auto inputs = tree_flatten(args, false);
{
auto flat_kwargs = tree_flatten(kwargs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_kwargs.begin()),
std::make_move_iterator(flat_kwargs.end()));
}
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
@@ -616,8 +635,39 @@ struct PyCompiledFun {
std::make_move_iterator(flat_in_captures.end()));
}
// Collect the compilation constants
std::vector<uint64_t> constants;
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
// Consider expanding tuples to their contents including start and end
// ids
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
auto r = py::hash(o);
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::int_>(o)) {
auto r = o.cast<int64_t>();
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::float_>(o)) {
auto r = o.cast<double>();
return *reinterpret_cast<uint64_t*>(&r);
} else {
return std::nullopt;
}
};
for (int i = 0; i < args.size(); i++) {
if (auto h = value_hash(args[i]); h.has_value()) {
constants.push_back(*h);
}
}
for (auto& pair : kwargs) {
if (auto h = value_hash(pair.second); h.has_value()) {
constants.push_back(*value_hash(pair.first));
constants.push_back(*h);
}
}
// Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) {
std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
@@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) {
"compile",
[](const py::function& fun,
const py::object& inputs,
const py::object& outputs) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
const py::object& outputs,
bool shapeless) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
},
"fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
"shapeless"_a = false,
R"pbdoc(
compile(fun: function) -> function
@@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) {
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
Default: ``None``
shapeless (bool, optional): A function compiled with the ``shapeless``
option enabled will not be recompiled when the input shape changes. Not all
functions can be compiled with ``shapeless`` enabled. Attempting to compile
such functions with shapeless enabled will throw. Note, changing the number
of dimensions or type of any input will result in a recompilation even with
``shapeless`` set to ``True``. Default: ``False``
Returns:
function: A compiled function which has the same input arguments