Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-27 11:31:21 +08:00
Fix imag() vjp (#2367)
commit 5201df5030
parent 2d3c26c565
@@ -2459,7 +2459,7 @@ std::vector<array> Imag::vjp(
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   return {multiply(
-      array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
+      array(complex64_t{0.0f, 1.0f}, primals[0].dtype()),
       cotangents[0],
       stream())};
 }
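The change above flips the constant that Imag::vjp multiplies the incoming cotangent by, from -i to +i, which in turn flips the sign of the gradient reported for functions of imag(x) (see the test update further down). A minimal sketch of how to observe the new behaviour from Python, assuming the mlx.core vjp/imag API:

import mlx.core as mx

# Sketch only, not part of the commit: pull a real cotangent of ones back
# through mx.imag; with the fixed vjp each element comes back as +1j.
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
cotan = mx.ones(x.shape)  # real-valued cotangent for the real output of imag()
_, vjps = mx.vjp(mx.imag, [x], [cotan])
assert mx.allclose(vjps[0], 1j * mx.ones_like(x))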
@@ -2788,15 +2788,19 @@ std::vector<array> Matmul::vjp(
   std::vector<int> reorder(cotan.ndim());
   std::iota(reorder.begin(), reorder.end(), 0);
   std::iter_swap(reorder.end() - 1, reorder.end() - 2);
+  auto& s = stream();
+
+  auto complex_transpose = [&](const array& x) {
+    return transpose(conjugate(x, s), reorder, s);
+  };
+
   for (auto arg : argnums) {
     if (arg == 0) {
       // M X N * (K X N).T -> M X K
-      vjps.push_back(
-          matmul(cotan, transpose(primals[1], reorder, stream()), stream()));
+      vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s));
     } else {
       // (M X K).T * M X N -> K X N
-      vjps.push_back(
-          matmul(transpose(primals[0], reorder, stream()), cotan, stream()));
+      vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s));
     }
   }
   return vjps;
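With complex operands the plain transpose in the old code gives the wrong pullback; the new complex_transpose helper conjugates before transposing, so each vjp is the cotangent multiplied by the conjugate (Hermitian) transpose of the other operand. For real inputs conjugate() is a mathematical no-op, so real-valued behaviour is unchanged. A hedged numerical sketch of that property, assuming mx.vjp, mx.conjugate, and mx.transpose from the Python API:

import mlx.core as mx

# Sketch only, not part of the commit: for C = A @ B the pullbacks should be
# dA = cotan @ conj(B)^T and dB = conj(A)^T @ cotan.
a = mx.array([[1.0 + 2.0j, 0.5 - 1.0j], [2.0 + 0.0j, -1.0 + 1.0j]])
b = mx.array([[0.0 + 1.0j, 1.0 + 0.0j], [1.0 - 1.0j, 2.0 + 2.0j]])
cotan = mx.ones_like(a @ b)

_, vjps = mx.vjp(lambda a, b: a @ b, [a, b], [cotan])
assert mx.allclose(vjps[0], cotan @ mx.conjugate(mx.transpose(b)))
assert mx.allclose(vjps[1], mx.conjugate(mx.transpose(a)) @ cotan)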
@@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase):

         x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
         dfdx = mx.grad(fun)(x)
-        self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
+        self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))

     def test_flatten_unflatten_vjps(self):
         def fun(x):
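The expected gradient in the test flips sign to match the +i factor in the fixed Imag::vjp. The fun under test is defined above the shown hunk and is not part of this diff; a hypothetical stand-in that would produce the expected 2j gradient is:

import mlx.core as mx

def fun(x):
    # Hypothetical stand-in for the test's `fun` (the real definition is not
    # shown in this hunk): a real-valued function of the imaginary part.
    return (2.0 * mx.imag(x)).sum()

x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
dfdx = mx.grad(fun)(x)
assert mx.allclose(dfdx, 2j * mx.ones_like(x))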