More helpful error message in vjp transform + concate bug (#543)

* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
This commit is contained in:
Awni Hannun
2024-01-24 09:58:33 -08:00
committed by GitHub
parent f30e63353a
commit f27ec5e097
3 changed files with 31 additions and 9 deletions

View File

@@ -337,12 +337,21 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
if (cotan_index >= cotans.size()) {
throw std::invalid_argument(
"[vjp] Number of outputs with gradient does not match number of cotangents.");
std::ostringstream msg;
msg << "[vjp] Number of outputs to compute gradients for ("
<< outputs.size() << ") does not match number of cotangents ("
<< cotans.size() << ").";
throw std::invalid_argument(msg.str());
}
if (out.shape() != cotans[cotan_index].shape()) {
throw std::invalid_argument(
"[vjp] Output shape does not match shape of cotangent.");
std::ostringstream msg;
msg << "[vjp] Output shape " << out.shape()
<< " does not match cotangent shape " << cotans[cotan_index].shape()
<< ".";
if (outputs.size() == 1 && out.size() == 1) {
msg << " If you are using grad your function must return a scalar.";
}
throw std::invalid_argument(msg.str());
}
output_cotan_pairs.emplace_back(i, cotan_index++);
}