Custom transforms (#1246)

This commit is contained in:
Angelos Katharopoulos
2024-07-10 18:00:01 -07:00
committed by GitHub
parent a3c287354f
commit 5c1fa64fb0
16 changed files with 734 additions and 39 deletions

View File

@@ -33,8 +33,11 @@ class Synchronizer : public Primitive {
// Initialize the static tracing counters from transforms_impl.h.
//
// This is used to implement the in_tracing() function that returns true if we
// are currently under a function transformation.
// are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during
// evaluation.
int detail::InTracing::tracing_counter{0};
int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) {
std::queue<array> tape;
@@ -331,7 +334,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
std::vector<array> vjps;
{
detail::RetainGraph retain;
vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
}
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();
@@ -778,14 +785,27 @@ std::function<array(const array&)> vmap(
return [vfun](const array& a) { return vfun({a})[0]; };
}
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
return fun;
}
return [fun = std::move(fun),
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
fun_vjp = std::move(fun_vjp),
fun_jvp = std::move(fun_jvp),
fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
// Compute the outputs
auto outputs = fun(args);
for (auto& out : outputs) {
@@ -814,11 +834,63 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
return array::make_arrays(
std::move(shapes),
dtypes,
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
std::make_shared<CustomTransforms>(
to_stream(s),
outputs.size(),
// We use the passed vjp function or compute it from the inputs and
// passed cotangents. Note that this may be less efficient than
// using `fun` directly because we may not be able to fully reuse
// the outputs of the forward pass.
fun_vjp.value_or(
[fun](auto primals, auto cotangents, auto outputs) {
auto [__, vjps] = vjp(fun, primals, cotangents);
return vjps;
}),
// We use the passed jvp function or compute it from the primals
// and tangents. Similarly we can't take full advantage of the
// argnums so it is best to use `fun` directly if we don't need a
// custom transform.
//
// TODO: Use stop_gradient to make full use of argnums and not
// waste computation.
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));
}
}
auto [__, jvps] = jvp(fun, primals, all_tangents);
return jvps;
}),
// Same as above, we use the passed vmap function or we compute it
// from `fun`. The output axes are all set to 0, which again may be
// suboptimal but is the only thing we can do without any information
// about `fun`.
fun_vmap.value_or(
[fun, out_size = outputs.size()](auto inputs, auto in_axes)
-> std::pair<std::vector<array>, std::vector<int>> {
std::vector<int> out_axes(out_size, 0);
return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
})),
inputs);
};
}
// Convenience wrapper around custom_function() that supplies only a custom
// vector-Jacobian product for `fun`.
//
// `fun_vjp` is forwarded as the vjp argument, while the jvp and vmap
// arguments are passed as std::nullopt. Both callables are taken by value, so
// they are moved into custom_function() instead of being copied a second
// time.
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
    std::function<std::vector<array>(const std::vector<array>&)> fun,
    std::function<std::vector<array>(
        const std::vector<array>&,
        const std::vector<array>&,
        const std::vector<array>&)> fun_vjp) {
  return custom_function(
      std::move(fun), std::move(fun_vjp), std::nullopt, std::nullopt);
}
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun) {
auto vjp_fun = [fun](