mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
fix unflatten vjp (#1708)
This commit is contained in:
@@ -1716,7 +1716,7 @@ std::vector<array> Unflatten::vjp(
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {flatten(cotangents[0], axis_, axis_ + shape_.size(), stream())};
|
||||
return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Unflatten::jvp(
|
||||
|
Reference in New Issue
Block a user