mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal * fix log2
This commit is contained in:
parent
b2d2b37888
commit
28f39e9038
@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
|
|||||||
DEFAULT_UNARY(expm1, std::expm1)
|
DEFAULT_UNARY(expm1, std::expm1)
|
||||||
DEFAULT_UNARY(floor, std::floor)
|
DEFAULT_UNARY(floor, std::floor)
|
||||||
DEFAULT_UNARY(log, std::log)
|
DEFAULT_UNARY(log, std::log)
|
||||||
DEFAULT_UNARY(log2, std::log2)
|
|
||||||
DEFAULT_UNARY(log10, std::log10)
|
DEFAULT_UNARY(log10, std::log10)
|
||||||
DEFAULT_UNARY(log1p, std::log1p)
|
DEFAULT_UNARY(log1p, std::log1p)
|
||||||
DEFAULT_UNARY(sinh, std::sinh)
|
DEFAULT_UNARY(sinh, std::sinh)
|
||||||
@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
|
|||||||
DEFAULT_UNARY(tan, std::tan)
|
DEFAULT_UNARY(tan, std::tan)
|
||||||
DEFAULT_UNARY(tanh, std::tanh)
|
DEFAULT_UNARY(tanh, std::tanh)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||||
|
if constexpr (is_complex<T>) {
|
||||||
|
auto out = std::log(in.value);
|
||||||
|
auto scale = decltype(out.real())(M_LN2);
|
||||||
|
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
|
||||||
|
} else {
|
||||||
|
return Simd<T, 1>{std::log2(in.value)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Simd<T, 1> operator~(Simd<T, 1> in) {
|
Simd<T, 1> operator~(Simd<T, 1> in) {
|
||||||
return ~in.value;
|
return ~in.value;
|
||||||
|
@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
|||||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||||
|
@ -257,6 +257,13 @@ struct Log {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::log(x);
|
return metal::precise::log(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
auto r = metal::precise::log(Abs{}(x).real);
|
||||||
|
auto i = metal::precise::atan2(x.imag, x.real);
|
||||||
|
return {r, i};
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log2 {
|
struct Log2 {
|
||||||
@ -264,6 +271,12 @@ struct Log2 {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::log2(x);
|
return metal::precise::log2(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log10 {
|
struct Log10 {
|
||||||
@ -271,6 +284,12 @@ struct Log10 {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::log10(x);
|
return metal::precise::log10(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log1p {
|
struct Log1p {
|
||||||
|
@ -845,6 +845,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
a = mx.array(1.0) + 1j * mx.array(2.0)
|
||||||
|
result = mx.log(a)
|
||||||
|
expected = np.log(np.array(a))
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
def test_log2(self):
|
def test_log2(self):
|
||||||
a = mx.array([0.5, 1, 2, 10, 16])
|
a = mx.array([0.5, 1, 2, 10, 16])
|
||||||
result = mx.log2(a)
|
result = mx.log2(a)
|
||||||
@ -852,6 +857,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
a = mx.array(1.0) + 1j * mx.array(2.0)
|
||||||
|
result = mx.log2(a)
|
||||||
|
expected = np.log2(np.array(a))
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
def test_log10(self):
|
def test_log10(self):
|
||||||
a = mx.array([0.1, 1, 10, 20, 100])
|
a = mx.array([0.1, 1, 10, 20, 100])
|
||||||
result = mx.log10(a)
|
result = mx.log10(a)
|
||||||
@ -859,6 +869,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
a = mx.array(1.0) + 1j * mx.array(2.0)
|
||||||
|
result = mx.log10(a)
|
||||||
|
expected = np.log10(np.array(a))
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
def test_exp(self):
|
def test_exp(self):
|
||||||
a = mx.array([0, 0.5, -0.5, 5])
|
a = mx.array([0, 0.5, -0.5, 5])
|
||||||
result = mx.exp(a)
|
result = mx.exp(a)
|
||||||
|
Loading…
Reference in New Issue
Block a user