Log for complex numbers in Metal (#2025)

* Log for complex numbers in Metal

* fix log2
This commit is contained in:
Awni Hannun 2025-03-30 17:04:38 -07:00 committed by GitHub
parent b2d2b37888
commit 28f39e9038
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 48 additions and 1 deletions

View File

@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sinh, std::sinh)
@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh) DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T> template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) { Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value; return ~in.value;

View File

@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@ -257,6 +257,13 @@ struct Log {
T operator()(T x) { T operator()(T x) {
return metal::precise::log(x); return metal::precise::log(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
}; };
struct Log2 { struct Log2 {
@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) { T operator()(T x) {
return metal::precise::log2(x); return metal::precise::log2(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
}; };
struct Log10 { struct Log10 {
@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) { T operator()(T x) {
return metal::precise::log10(x); return metal::precise::log10(x);
}; };
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
}; };
struct Log1p { struct Log1p {

View File

@ -845,6 +845,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected)) self.assertTrue(np.allclose(result, expected))
a = mx.array(1.0) + 1j * mx.array(2.0)
result = mx.log(a)
expected = np.log(np.array(a))
self.assertTrue(np.allclose(result, expected))
def test_log2(self): def test_log2(self):
a = mx.array([0.5, 1, 2, 10, 16]) a = mx.array([0.5, 1, 2, 10, 16])
result = mx.log2(a) result = mx.log2(a)
@ -852,6 +857,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected)) self.assertTrue(np.allclose(result, expected))
a = mx.array(1.0) + 1j * mx.array(2.0)
result = mx.log2(a)
expected = np.log2(np.array(a))
self.assertTrue(np.allclose(result, expected))
def test_log10(self): def test_log10(self):
a = mx.array([0.1, 1, 10, 20, 100]) a = mx.array([0.1, 1, 10, 20, 100])
result = mx.log10(a) result = mx.log10(a)
@ -859,6 +869,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected)) self.assertTrue(np.allclose(result, expected))
a = mx.array(1.0) + 1j * mx.array(2.0)
result = mx.log10(a)
expected = np.log10(np.array(a))
self.assertTrue(np.allclose(result, expected))
def test_exp(self): def test_exp(self):
a = mx.array([0, 0.5, -0.5, 5]) a = mx.array([0, 0.5, -0.5, 5])
result = mx.exp(a) result = mx.exp(a)