Mirror of https://github.com/ml-explore/mlx.git
Complex scan (#2094)
commit 1d2c9d6a07
parent e8ac6bd2f5
@@ -172,9 +172,12 @@ void binary_float(
       case bfloat16:
         binary_op<bfloat16_t, Op>(a, b, out, bopt);
         break;
+      case complex64:
+        binary_op<complex64_t, Op>(a, b, out, bopt);
+        break;
       default:
         throw std::runtime_error(
-            "[binary_float] Only supports non-complex floating point types.");
+            "[binary_float] Only supports floating point types.");
     }
   });
 }
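
For context (not part of the diff): assuming mx.logaddexp is one of the ops routed through binary_float, the new complex64 case means a call like the following no longer hits the default throw. A minimal sketch, checked against a direct NumPy evaluation; the values and tolerance are illustrative only.

import numpy as np
import mlx.core as mx

a = mx.array([0.5 + 1.0j, 2.0 - 0.3j])
b = mx.array([1.5 - 2.0j, 0.1 + 0.7j])
out = mx.logaddexp(a, b)  # previously raised for complex inputs on this path

# Reference: log(exp(a) + exp(b)) evaluated directly in NumPy (complex128).
ref = np.log(np.exp(np.array(a)) + np.exp(np.array(b)))
print(np.allclose(np.array(out), ref, atol=1e-5))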
@@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
             reduce_type_, in, out, axis_, reverse_, inclusive_);
         break;
       case complex64:
-        throw std::runtime_error("Scan ops do not support complex types yet");
+        scan_dispatch<complex64_t, complex64_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
     }
   });
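
Again for context only: a small sketch of what the complex64 dispatch above unlocks on the CPU path, assuming cumsum and cumprod are among the scan ops handled by Scan::eval_cpu (the tests at the end of the diff exercise exactly this). Checked against NumPy's own scans; shapes and tolerances are illustrative.

import numpy as np
import mlx.core as mx

a_npy = (np.random.randn(4, 5) + 1j * np.random.randn(4, 5)).astype(np.complex64)
a_mlx = mx.array(a_npy)

for np_op, mx_op in [(np.cumsum, mx.cumsum), (np.cumprod, mx.cumprod)]:
    ref = np_op(a_npy, axis=1)
    out = mx_op(a_mlx, axis=1)   # no longer throws for complex64 inputs
    print(np.allclose(np.array(out), ref, rtol=1e-3, atol=1e-3))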
@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
 DEFAULT_UNARY(floor, std::floor)
 DEFAULT_UNARY(log, std::log)
 DEFAULT_UNARY(log10, std::log10)
-DEFAULT_UNARY(log1p, std::log1p)
 DEFAULT_UNARY(sinh, std::sinh)
 DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)
 
+template <typename T>
+Simd<T, 1> log1p(Simd<T, 1> in) {
+  if constexpr (is_complex<T>) {
+    auto x = in.value.real();
+    auto y = in.value.imag();
+    auto zabs = std::abs(in.value);
+    auto theta = std::atan2(y, x + 1);
+    if (zabs < 0.5) {
+      auto r = x * (2 + x) + y * y;
+      if (r == 0) { // handle underflow
+        return Simd<T, 1>{T{x, theta}};
+      }
+      return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
+    } else {
+      auto z0 = std::hypot(x + 1, y);
+      return Simd<T, 1>{T{std::log(z0), theta}};
+    }
+  } else {
+    return Simd<T, 1>{std::log1p(in.value)};
+  }
+}
+
 template <typename T>
 Simd<T, 1> log2(Simd<T, 1> in) {
   if constexpr (is_complex<T>) {
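
The branchy complex log1p above is easier to follow in scalar form. Below is an illustrative Python port (not MLX code): theta is the imaginary part arg(1 + z); for small |z| the real part is computed as log1p(|1 + z|^2 - 1) / 2 to avoid cancellation, with r == 0 treated as underflow; otherwise it falls back to log(hypot(x + 1, y)).

import numpy as np

def complex_log1p_reference(z):
    # Illustrative Python port of the complex branch above, not MLX code.
    x, y = z.real, z.imag
    theta = np.arctan2(y, x + 1)        # imaginary part: arg(1 + z)
    if abs(z) < 0.5:
        r = x * (2 + x) + y * y         # |1 + z|^2 - 1, formed without cancellation
        if r == 0:                      # handle underflow
            return complex(x, theta)
        return complex(0.5 * np.log1p(r), theta)
    else:
        return complex(np.log(np.hypot(x + 1, y)), theta)

z = 1e-4 + 2e-4j
print(complex_log1p_reference(z))
print(np.log1p(z))  # should agree closely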
@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
 instantiate_binary_types_bool(LessEqual)
 instantiate_binary_types_bool(NotEqual)
 instantiate_binary_float(LogAddExp)
+instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
 instantiate_binary_types(Maximum)
 instantiate_binary_types(Minimum)
 instantiate_binary_types(Multiply)
@@ -130,6 +130,24 @@ struct LogAddExp {
         ? maxval
         : (maxval + log1p(metal::exp(minval - maxval)));
   };
+
+  complex64_t operator()(complex64_t x, complex64_t y) {
+    if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
+        metal::isnan(y.imag)) {
+      return metal::numeric_limits<float>::quiet_NaN();
+    }
+    constexpr float inf = metal::numeric_limits<float>::infinity();
+    complex64_t maxval = x > y ? x : y;
+    complex64_t minval = x < y ? x : y;
+    if (minval.real == -inf || maxval.real == inf)
+      return maxval;
+    float m = metal::exp(minval.real - maxval.real);
+    complex64_t dexp{
+        m * metal::cos(minval.imag - maxval.imag),
+        m * metal::sin(minval.imag - maxval.imag),
+    };
+    return maxval + log1p(dexp);
+  }
 };
 
 struct Maximum {
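
An illustrative scalar port of the complex LogAddExp operator above, in Python rather than Metal. The operand ordering assumes MLX's complex comparison is lexicographic on (real, imag); the last two lines show the result agrees with a direct log(exp(x) + exp(y)), up to a possible 2*pi wrap of the imaginary part.

import cmath
import math

def complex_logaddexp_reference(x, y):
    # Illustrative Python port of the complex operator() above, not the Metal code itself.
    if math.isnan(x.real) or math.isnan(x.imag) or math.isnan(y.real) or math.isnan(y.imag):
        return complex(float("nan"), float("nan"))
    # Assumption: the x > y / x < y comparisons in the kernel order complex values
    # lexicographically by (real, imag); that ordering is used here.
    maxval, minval = (x, y) if (x.real, x.imag) >= (y.real, y.imag) else (y, x)
    if minval.real == float("-inf") or maxval.real == float("inf"):
        return maxval
    m = math.exp(minval.real - maxval.real)
    dexp = complex(
        m * math.cos(minval.imag - maxval.imag),
        m * math.sin(minval.imag - maxval.imag),
    )
    return maxval + cmath.log(1 + dexp)  # cmath.log(1 + z) stands in for log1p(z)

x, y = 0.3 + 0.2j, -1.0 + 0.5j
print(complex_logaddexp_reference(x, y))       # ~0.5335 + 0.2639j
print(cmath.log(cmath.exp(x) + cmath.exp(y)))  # same value for these inputs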
@@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
 instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
 instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
 instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
-instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
+instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
@@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t)
 instantiate_unary_all_same(Cosh, complex64, complex64_t)
 instantiate_unary_all_same(Exp, complex64, complex64_t)
 instantiate_unary_all_same(Log, complex64, complex64_t)
+instantiate_unary_all_same(Log1p, complex64, complex64_t)
 instantiate_unary_all_same(Log2, complex64, complex64_t)
 instantiate_unary_all_same(Log10, complex64, complex64_t)
 instantiate_unary_all_same(Negative, complex64, complex64_t)
@@ -328,6 +328,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
   return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
 }
 
+inline complex64_t log1p(complex64_t in) {
+  float x = in.real;
+  float y = in.imag;
+  float zabs = metal::precise::sqrt(x * x + y * y);
+  float theta = metal::atan2(y, x + 1);
+  if (zabs < 0.5f) {
+    float r = x * (2 + x) + y * y;
+    if (r == 0) { // handle underflow
+      return {x, theta};
+    }
+    return {0.5f * log1p(r), theta};
+  } else {
+    auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
+    return {metal::log(z0), theta};
+  }
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // SIMD shuffle ops
 ///////////////////////////////////////////////////////////////////////////////
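
A small numeric aside on why the |z| < 0.5 branch exists in both the CPU and Metal versions: in float32, 1 + x rounds to 1 for tiny x, so the large-|z| formula loses the signal, while r = x*(2 + x) + y*y keeps it. The values below are illustrative only.

import numpy as np

x = np.float32(1e-8)
y = np.float32(1e-8)

naive = np.log(np.hypot(np.float32(1) + x, y))   # large-|z| formula: 1 + x rounds to 1
r = x * (np.float32(2) + x) + y * y              # |1 + z|^2 - 1 without forming 1 + x
careful = np.float32(0.5) * np.log1p(r)          # small-|z| formula

print(naive)    # 0.0 -- the 1e-8 contribution is lost in float32
print(careful)  # ~1e-8, matching the real part of log1p(x + 1j*y)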
@@ -10,6 +10,47 @@ import mlx_tests
 import numpy as np
 
 
+def np_wrap_between(x, a):
+    """Wraps `x` between `[-a, a]`."""
+    two_a = 2 * a
+    zero = 0
+    rem = np.remainder(np.add(x, a), two_a)
+    if isinstance(rem, np.ndarray):
+        rem = np.select(rem < zero, np.add(rem, two_a), rem)
+    else:
+        rem = np.add(rem, two_a) if rem < zero else rem
+    return np.subtract(rem, a)
+
+
+def np_logaddexp(x1: np.ndarray, x2: np.ndarray):
+    amax = np.maximum(x1, x2)
+    if np.issubdtype(x1.dtype, np.floating):
+        delta = np.subtract(x1, x2)
+        if isinstance(delta, np.ndarray):
+            return np.select(
+                np.isnan(delta),
+                np.add(x1, x2),
+                np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))),
+            )
+        else:
+            return (
+                np.add(x1, x2)
+                if np.isnan(delta)
+                else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta)))))
+            )
+    else:
+        delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2))
+        out = np.add(amax, np.log1p(np.exp(delta)))
+        return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi)
+
+
+def np_cumlogaddexp(x1: np.ndarray, axis: int = -1):
+    out = x1.copy()
+    for i in range(1, out.shape[axis]):
+        out[i] = np_logaddexp(out[i], out[i - 1])
+    return out
+
+
 class TestOps(mlx_tests.MLXTestCase):
     def test_full_ones_zeros(self):
         x = mx.full(2, 3.0)
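
A note on the helpers just added: log(exp(x) + exp(y)) is only determined up to multiples of 2*pi*i, so np_logaddexp wraps the imaginary part of its complex reference value before comparison against MLX's result. A condensed sketch of the same wrapping idea, with illustrative inputs:

import numpy as np

two_pi = 2 * np.pi
vals = np.array([0.1, np.pi + 0.1, -np.pi - 0.1, 3 * np.pi])
wrapped = np.remainder(vals + np.pi, two_pi) - np.pi  # same idea as np_wrap_between(vals, np.pi)
print(wrapped)  # approximately [0.1, -pi + 0.1, pi - 0.1, -pi]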
@@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase):
 
         self.assertTrue(np.allclose(result, expected))
 
+        # Complex test
+
+        a = mx.array([0, 1, 2, 9.0]) + 1j
+        b = mx.array([1, 0, 4, 2.5]) + 1j
+
+        result = mx.logaddexp(a, b)
+        expected = np_logaddexp(np.array(a), np.array(b))
+
+        self.assertTrue(np.allclose(result, expected))
+
         a = mx.array([float("nan")])
         b = mx.array([0.0])
         self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
@@ -977,6 +1028,13 @@ class TestOps(mlx_tests.MLXTestCase):
 
         self.assertTrue(np.allclose(result, expected))
 
+        # Complex test
+        a = mx.array([1, 0.5, 10, 100]) + 1j
+        result = mx.log1p(a)
+        expected = np.log1p(a, dtype=np.complex64)
+
+        self.assertTrue(np.allclose(result, expected))
+
     def test_sigmoid(self):
         a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0])
         result = mx.sigmoid(a)
@@ -1881,10 +1939,31 @@ class TestOps(mlx_tests.MLXTestCase):
         c_mlx = mxop(a_mlx, axis=0)
         self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
 
+        # Complex tests
+
+        a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j
+        a_mlx = mx.array(a_npy)
+        c_npy = np_cumlogaddexp(a_npy, axis=-1)
+        c_mlx = mxop(a_mlx, axis=-1)
+        self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
+
     def test_scans(self):
         a_npy = np.random.randn(32, 32, 32).astype(np.float32)
         a_mlx = mx.array(a_npy)
 
         for op in ["cumsum", "cumprod"]:
             npop = getattr(np, op)
             mxop = getattr(mx, op)
             for axis in (None, 0, 1, 2):
                 c_npy = npop(a_npy, axis=axis)
                 c_mlx = mxop(a_mlx, axis=axis)
                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
 
+        # Complex test
+
+        a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j
+        a_mlx = mx.array(a_npy)
+
+        for op in ["cumsum", "cumprod"]:
+            npop = getattr(np, op)
+            mxop = getattr(mx, op)