mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix unflatten vjp (#1708)
This commit is contained in:
parent
a82996e9fb
commit
d03c01dfbc
@ -1716,7 +1716,7 @@ std::vector<array> Unflatten::vjp(
|
|||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>&,
|
const std::vector<int>&,
|
||||||
const std::vector<array>&) {
|
const std::vector<array>&) {
|
||||||
return {flatten(cotangents[0], axis_, axis_ + shape_.size(), stream())};
|
return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Unflatten::jvp(
|
std::vector<array> Unflatten::jvp(
|
||||||
|
@ -605,6 +605,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
dfdx = mx.grad(fun)(x)
|
dfdx = mx.grad(fun)(x)
|
||||||
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
|
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
|
||||||
|
|
||||||
|
def test_flatten_unflatten_vjps(self):
|
||||||
|
def fun(x):
|
||||||
|
y = mx.unflatten(x, 0, (2, 2))
|
||||||
|
return y.sum()
|
||||||
|
|
||||||
|
x = mx.zeros((4, 8))
|
||||||
|
self.assertEqual(mx.grad(fun)(x).shape, (4, 8))
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
y = mx.flatten(x, 0, 2)
|
||||||
|
return y.sum()
|
||||||
|
|
||||||
|
x = mx.zeros((2, 4, 8))
|
||||||
|
self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user