From 95d04805b312db955bb76255ff02f6f7132fb17a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 6 Oct 2024 19:58:30 -0700 Subject: [PATCH] Fix complex power on Metal (#1460) --- mlx/backend/metal/kernels/binary_ops.h | 2 +- tests/ops_tests.cpp | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 9cd2126cd..8f961c2cf 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -217,7 +217,7 @@ struct Power { template <> complex64_t operator()(complex64_t x, complex64_t y) { - auto x_theta = metal::atan(x.imag / x.real); + auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); auto phase = y.imag * x_ln_r + y.real * x_theta; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 6eec3e99c..682f668c9 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2553,6 +2553,13 @@ TEST_CASE("test power") { auto out = (power(array(a), array(b))).item(); CHECK(abs(out.real() - expected.real()) < 1e-7); CHECK(abs(out.imag() - expected.imag()) < 1e-7); + + a = complex64_t{-1.2, 0.1}; + b = complex64_t{2.2, 0.0}; + expected = std::pow(a, b); + out = (power(array(a), array(b))).item(); + CHECK(abs(out.real() - expected.real()) < 1e-6); + CHECK(abs(out.imag() - expected.imag()) < 1e-6); } TEST_CASE("test where") {