mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-18 10:11:43 +08:00)
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth
* add grad of module test
@@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
 std::pair<std::vector<array>, std::vector<array>> vjp(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
     const std::vector<array>& primals,
-    const std::vector<array>& cotans) {
+    const std::vector<array>& cotans,
+    const std::vector<int>& argnums) {
   // Set the global tracing flag.
   detail::InTracing in_tracing;
 
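For orientation, a minimal call sketch of the new four-argument overload (not part of the commit; the lambda, the literal values, and the assumption that this overload is reachable from user code are all illustrative). Only the primals whose indices appear in argnums are marked for gradient computation, so the backward pass can skip subgraphs that merely feed non-grad inputs:

    // Sketch only: assumes the argnums-aware vjp overload is callable like this.
    auto fun = [](const std::vector<array>& ins) {
      return std::vector<array>{ins[0] * ins[1] + ins[2]};
    };
    std::vector<array> primals = {array(2.0f), array(3.0f), array(4.0f)};
    std::vector<array> cotans = {array(1.0f)};
    // Gradients w.r.t. primals 0 and 1 only; argnums is assumed to be sorted.
    auto [outputs, vjps] = vjp(fun, primals, cotans, {0, 1});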
@@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // to the tape which need a gradient.
   std::unordered_set<std::uintptr_t> cache;
   std::unordered_set<std::uintptr_t> calc_grad;
-  for (auto& primal : primals_) {
+  for (int i = 0, j = 0; i < primals_.size(); ++i) {
+    auto& primal = primals_[i];
     primal.set_tracer(false);
-    calc_grad.insert(primal.id());
     cache.insert(primal.id());
+    if (j < argnums.size() && argnums[j] == i) {
+      j++;
+      calc_grad.insert(primal.id());
+    }
   }
 
   std::vector<array> tape;
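The rewritten loop walks primals_ and argnums in lockstep, which is only correct when argnums is sorted in ascending order (the value_and_grad hunk further down constructs sorted_argnums for exactly this reason). A self-contained sketch of that two-pointer matching logic, independent of MLX:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> argnums = {0, 2};  // sorted indices that need gradients
      int n_primals = 4;
      for (int i = 0, j = 0; i < n_primals; ++i) {
        // Same two-pointer scan as in the hunk above: advance j on a match.
        bool needs_grad =
            j < static_cast<int>(argnums.size()) && argnums[j] == i;
        if (needs_grad) {
          ++j;
        }
        std::printf("primal %d: %s\n", i, needs_grad ? "calc_grad" : "cache only");
      }
      return 0;
    }

Every primal still lands in cache, but only the matched ones enter calc_grad, which is what stops the backward recursion from descending into gradient-free subgraphs.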
@@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
     }
   }
   std::vector<array> vjps;
-  for (auto& primal : primals_) {
+  for (auto arg : argnums) {
+    auto& primal = primals_[arg];
     if (auto cotan_it = cotan_map.find(primal.id());
         cotan_it != cotan_map.end()) {
       vjps.push_back(cotan_it->second);
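An observation, not part of the diff: because this gather loop now iterates over argnums instead of all of primals_, the returned vjps vector is sized and ordered by argnums, so callers get exactly one cotangent per requested input.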
@@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   return {outputs, vjps};
 }
 
+std::pair<std::vector<array>, std::vector<array>> vjp(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<array>& primals,
+    const std::vector<array>& cotans) {
+  std::vector<int> argnums(primals.size());
+  std::iota(argnums.begin(), argnums.end(), 0);
+  return vjp(fun, primals, cotans, argnums);
+}
+
 std::pair<array, array> vjp(
     const std::function<array(const array&)>& fun,
     const array& primal,
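The new three-argument overload keeps the previous differentiate-everything behavior by defaulting argnums to 0..primals.size()-1 with std::iota (declared in <numeric>, which this file presumably includes). A tiny sketch of what the default produces:

    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int> argnums(4);                    // e.g. four primals
      std::iota(argnums.begin(), argnums.end(), 0);   // argnums == {0, 1, 2, 3}
      return 0;
    }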
@@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
           << inputs.size() << " inputs.";
       throw std::invalid_argument(msg.str());
     }
+    std::vector<int> sorted_argnums(args.begin(), args.end());
 
-    auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
-      std::vector<array> inputs_(inputs);
-      auto argit = args.begin();
-      for (int i = 0; i < ginputs.size(); ++i) {
-        inputs_[*argit] = ginputs[i];
-        ++argit;
-      }
-      auto outputs = fun(inputs_);
+    auto gfun = [&fun](const std::vector<array>& inputs) {
+      auto outputs = fun(inputs);
       for (int i = 1; i < outputs.size(); i++) {
         auto& out = outputs[i];
         auto s = out.has_primitive() ? out.primitive().stream()
@@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
       return outputs;
     };
 
-    std::vector<array> ginputs;
-    for (auto arg : args) {
-      ginputs.push_back(inputs[arg]);
-    }
     // Set the incoming gradient to float32, vjp will cast it to the output type
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
+    auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
     return std::make_pair(outputs, grads);
   };
 }
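Taken together, the last two hunks simplify value_and_grad: instead of gathering ginputs and scattering them back into a copy of inputs inside gfun (which made every input look gradient-relevant to vjp), the full inputs vector is passed through unchanged and sorted_argnums tells vjp which entries need gradients. A hedged end-to-end sketch, assuming value_and_grad, sum, and the array initializer-list constructor behave as in the MLX C++ API:

    // Sketch only: call shapes are assumed, not taken from the diff.
    auto f = [](const std::vector<array>& ins) {
      return std::vector<array>{sum(ins[0] * ins[1])};
    };
    // Ask for the gradient w.r.t. input 0 only; input 1's subgraph is
    // no longer traversed during the backward pass.
    auto vag = value_and_grad(f, std::vector<int>{0});
    auto [values, grads] = vag({array({1.0f, 2.0f}), array({3.0f, 4.0f})});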