diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index eb5d9d6b3..72affbd34 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2459,7 +2459,7 @@ std::vector<array> Imag::vjp(
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   return {multiply(
-      array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
+      array(complex64_t{0.0f, 1.0f}, primals[0].dtype()),
       cotangents[0],
       stream())};
 }
@@ -2788,15 +2788,19 @@ std::vector<array> Matmul::vjp(
   std::vector<int> reorder(cotan.ndim());
   std::iota(reorder.begin(), reorder.end(), 0);
   std::iter_swap(reorder.end() - 1, reorder.end() - 2);
+  auto& s = stream();
+
+  auto complex_transpose = [&](const array& x) {
+    return transpose(conjugate(x, s), reorder, s);
+  };
+
   for (auto arg : argnums) {
     if (arg == 0) {
       // M X N * (K X N).T -> M X K
-      vjps.push_back(
-          matmul(cotan, transpose(primals[1], reorder, stream()), stream()));
+      vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s));
     } else {
       // (M X K).T * M X N -> K X N
-      vjps.push_back(
-          matmul(transpose(primals[0], reorder, stream()), cotan, stream()));
+      vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s));
     }
   }
   return vjps;
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 7973d79be..5722071f6 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
 
         x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
         dfdx = mx.grad(fun)(x)
-        self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
+        self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))
 
     def test_flatten_unflatten_vjps(self):
         def fun(x):
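
A minimal sketch (not part of the patch) of the convention the two C++ changes encode, assuming complex-valued matmul and its autograd are supported on your default device; the names A, B, and cotan are illustrative. After this patch, the pullback of imag scales the cotangent by +i, and the Matmul pullbacks use the conjugate transpose (conj(X).T) rather than the plain transpose:

import mlx.core as mx

# Imag: the vjp multiplies the (real) cotangent by +i, per the sign fix above.
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
_, (dx,) = mx.vjp(mx.imag, [x], [mx.ones_like(mx.imag(x))])
assert mx.allclose(dx, 1j * mx.ones_like(x))

# Matmul: vjps are cotan @ conj(B).T and conj(A).T @ cotan,
# matching the complex_transpose helper introduced above.
A = mx.array([[1.0 + 2.0j, 0.5 - 1.0j], [2.0 + 0.0j, -1.0 + 1.0j]])
B = mx.array([[0.0 + 1.0j, 1.0 + 1.0j], [3.0 - 2.0j, 0.5 + 0.0j]])
cotan = mx.ones_like(A @ B)
_, (dA, dB) = mx.vjp(lambda a, b: a @ b, [A, B], [cotan])
assert mx.allclose(dA, cotan @ mx.conj(B).T)
assert mx.allclose(dB, mx.conj(A).T @ cotan)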