fix unflatten vjp (#1708)

This commit is contained in:
Awni Hannun
2024-12-16 18:37:57 -08:00
committed by GitHub
parent a82996e9fb
commit d03c01dfbc
2 changed files with 16 additions and 1 deletions

View File

@@ -1716,7 +1716,7 @@ std::vector<array> Unflatten::vjp(
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {flatten(cotangents[0], axis_, axis_ + shape_.size(), stream())};
return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
}
std::vector<array> Unflatten::jvp(