diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index bc416fc22..7e82a4d56 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh) DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) -DEFAULT_UNARY(log2, std::log2) DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) @@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log2(Simd in) { + if constexpr (is_complex) { + auto out = std::log(in.value); + auto scale = decltype(out.real())(M_LN2); + return Simd{T{out.real() / scale, out.imag() / scale}}; + } else { + return Simd{std::log2(in.value)}; + } +} + template Simd operator~(Simd in) { return ~in.value; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 82692c8e5..2209b0665 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t) +instantiate_unary_all_same(Log, complex64, complex64_t) +instantiate_unary_all_same(Log2, complex64, complex64_t) +instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index ceed3efe5..52e126b40 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -257,6 +257,13 @@ struct Log { T operator()(T x) { return metal::precise::log(x); }; + + template <> + complex64_t operator()(complex64_t x) { + auto r = metal::precise::log(Abs{}(x).real); + auto i = metal::precise::atan2(x.imag, x.real); + return {r, i}; + }; }; struct Log2 { @@ -264,6 +271,12 @@ struct Log2 { T operator()(T x) { return metal::precise::log2(x); }; + + template <> + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN2_F, y.imag / M_LN2_F}; + }; }; struct Log10 { @@ -271,6 +284,12 @@ struct Log10 { T operator()(T x) { return metal::precise::log10(x); }; + + template <> + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN10_F, y.imag / M_LN10_F}; + }; }; struct Log1p { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 302c017a0..d7c79d9db 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -845,6 +845,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + a = mx.array(1.0) + 1j * mx.array(2.0) + result = mx.log(a) + expected = np.log(np.array(a)) + self.assertTrue(np.allclose(result, expected)) + def test_log2(self): a = mx.array([0.5, 1, 2, 10, 16]) result = mx.log2(a) @@ -852,6 +857,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + a = mx.array(1.0) + 1j * mx.array(2.0) + result = mx.log2(a) + expected = np.log2(np.array(a)) + self.assertTrue(np.allclose(result, expected)) + def test_log10(self): a = mx.array([0.1, 1, 10, 20, 100]) result = mx.log10(a) @@ -859,6 +869,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + a = mx.array(1.0) + 1j * mx.array(2.0) + result = mx.log10(a) + expected = np.log10(np.array(a)) + self.assertTrue(np.allclose(result, expected)) + def test_exp(self): a = mx.array([0, 0.5, -0.5, 5]) result = mx.exp(a)