Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
This commit is contained in:
Awni Hannun
2025-01-14 14:33:18 -08:00
committed by GitHub
parent 5cc5201914
commit 33421c1dd3
6 changed files with 136 additions and 100 deletions

View File

@@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotans) {
const std::vector<array>& cotans,
const std::vector<int>& argnums) {
// Set the global tracing flag.
detail::InTracing in_tracing;
@@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// to the tape which need a gradient.
std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> calc_grad;
for (auto& primal : primals_) {
for (int i = 0, j = 0; i < primals_.size(); ++i) {
auto& primal = primals_[i];
primal.set_tracer(false);
calc_grad.insert(primal.id());
cache.insert(primal.id());
if (j < argnums.size() && argnums[j] == i) {
j++;
calc_grad.insert(primal.id());
}
}
std::vector<array> tape;
@@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
std::vector<array> vjps;
for (auto& primal : primals_) {
for (auto arg : argnums) {
auto& primal = primals_[arg];
if (auto cotan_it = cotan_map.find(primal.id());
cotan_it != cotan_map.end()) {
vjps.push_back(cotan_it->second);
@@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
return {outputs, vjps};
}
// Convenience overload of vjp that takes no explicit argnums: it builds the
// index list 0, 1, ..., primals.size() - 1 (i.e. every primal) and forwards
// to the overload that accepts argnums.
std::pair<std::vector<array>, std::vector<array>> vjp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& cotans) {
  // One slot per primal, filled with consecutive indices starting at zero.
  std::vector<int> all_indices(primals.size());
  int next_index = 0;
  for (auto& index : all_indices) {
    index = next_index++;
  }
  return vjp(fun, primals, cotans, all_indices);
}
std::pair<array, array> vjp(
const std::function<array(const array&)>& fun,
const array& primal,
@@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
<< inputs.size() << " inputs.";
throw std::invalid_argument(msg.str());
}
std::vector<int> sorted_argnums(args.begin(), args.end());
auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
std::vector<array> inputs_(inputs);
auto argit = args.begin();
for (int i = 0; i < ginputs.size(); ++i) {
inputs_[*argit] = ginputs[i];
++argit;
}
auto outputs = fun(inputs_);
auto gfun = [&fun](const std::vector<array>& inputs) {
auto outputs = fun(inputs);
for (int i = 1; i < outputs.size(); i++) {
auto& out = outputs[i];
auto s = out.has_primitive() ? out.primitive().stream()
@@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
return outputs;
};
std::vector<array> ginputs;
for (auto arg : args) {
ginputs.push_back(inputs[arg]);
}
// Set the incoming gradient to float32, vjp will cast it to the output type
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
return std::make_pair(outputs, grads);
};
}