Fix imag() vjp (#2367)

This commit is contained in:
Angelos Katharopoulos 2025-07-14 13:11:16 -07:00 committed by GitHub
parent 2d3c26c565
commit 5201df5030
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 6 deletions

View File

@ -2459,7 +2459,7 @@ std::vector<array> Imag::vjp(
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(
array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
array(complex64_t{0.0f, 1.0f}, primals[0].dtype()),
cotangents[0],
stream())};
}
@ -2788,15 +2788,19 @@ std::vector<array> Matmul::vjp(
std::vector<int> reorder(cotan.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
std::iter_swap(reorder.end() - 1, reorder.end() - 2);
auto& s = stream();
auto complex_transpose = [&](const array& x) {
return transpose(conjugate(x, s), reorder, s);
};
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
vjps.push_back(
matmul(cotan, transpose(primals[1], reorder, stream()), stream()));
vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s));
} else {
// (M X K).T * M X N -> K X N
vjps.push_back(
matmul(transpose(primals[0], reorder, stream()), cotan, stream()));
vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s));
}
}
return vjps;

View File

@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
dfdx = mx.grad(fun)(x)
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))
def test_flatten_unflatten_vjps(self):
def fun(x):