mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 20:11:19 +08:00
Fix imag() vjp (#2367)
This commit is contained in:
parent
2d3c26c565
commit
5201df5030
@ -2459,7 +2459,7 @@ std::vector<array> Imag::vjp(
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {multiply(
|
||||
array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
|
||||
array(complex64_t{0.0f, 1.0f}, primals[0].dtype()),
|
||||
cotangents[0],
|
||||
stream())};
|
||||
}
|
||||
@ -2788,15 +2788,19 @@ std::vector<array> Matmul::vjp(
|
||||
std::vector<int> reorder(cotan.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
std::iter_swap(reorder.end() - 1, reorder.end() - 2);
|
||||
auto& s = stream();
|
||||
|
||||
auto complex_transpose = [&](const array& x) {
|
||||
return transpose(conjugate(x, s), reorder, s);
|
||||
};
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
vjps.push_back(
|
||||
matmul(cotan, transpose(primals[1], reorder, stream()), stream()));
|
||||
vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s));
|
||||
} else {
|
||||
// (M X K).T * M X N -> K X N
|
||||
vjps.push_back(
|
||||
matmul(transpose(primals[0], reorder, stream()), cotan, stream()));
|
||||
vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s));
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
|
@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
|
||||
dfdx = mx.grad(fun)(x)
|
||||
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
|
||||
self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))
|
||||
|
||||
def test_flatten_unflatten_vjps(self):
|
||||
def fun(x):
|
||||
|
Loading…
Reference in New Issue
Block a user