Custom transforms (#1246)

This commit is contained in:
Angelos Katharopoulos
2024-07-10 18:00:01 -07:00
committed by GitHub
parent a3c287354f
commit 5c1fa64fb0
16 changed files with 734 additions and 39 deletions

View File

@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
return {{cosh(inputs[0], stream())}, axes};
}
std::vector<array> CustomVJP::vjp(
std::vector<array> CustomTransforms::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
// Extract the inputs to the VJP function
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
// Compute all the vjps
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
for (const auto& cot : cotangents) {
all_vjps.emplace_back(cot);
}
// Select the vjps requested
std::vector<array> vjps;
vjps.reserve(argnums.size());
for (auto arg : argnums) {
@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
return vjps;
}
std::vector<array> CustomTransforms::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Extract the inputs to the JVP function
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
// Compute the jvps
return jvp_fun_(inputs, tangents, argnums);
}
std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
const std::vector<array>& inputs_,
const std::vector<int>& axes_) {
// Extract the inputs to the vmap function
std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
return vmap_fun_(inputs, axes);
}
std::vector<array> Depends::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,