Fix complex power on Metal (#1460)

This commit is contained in:
Awni Hannun 2024-10-06 19:58:30 -07:00 committed by GitHub
parent e4534dac17
commit 95d04805b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 1 deletions

View File

@ -217,7 +217,7 @@ struct Power {
template <> template <>
complex64_t operator()(complex64_t x, complex64_t y) { complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real); auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta; auto phase = y.imag * x_ln_r + y.real * x_theta;

View File

@ -2553,6 +2553,13 @@ TEST_CASE("test power") {
auto out = (power(array(a), array(b))).item<complex64_t>(); auto out = (power(array(a), array(b))).item<complex64_t>();
CHECK(abs(out.real() - expected.real()) < 1e-7); CHECK(abs(out.real() - expected.real()) < 1e-7);
CHECK(abs(out.imag() - expected.imag()) < 1e-7); CHECK(abs(out.imag() - expected.imag()) < 1e-7);
a = complex64_t{-1.2, 0.1};
b = complex64_t{2.2, 0.0};
expected = std::pow(a, b);
out = (power(array(a), array(b))).item<complex64_t>();
CHECK(abs(out.real() - expected.real()) < 1e-6);
CHECK(abs(out.imag() - expected.imag()) < 1e-6);
} }
TEST_CASE("test where") { TEST_CASE("test where") {