mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix power (#2523)
This commit is contained in:
@@ -234,6 +234,7 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
|
||||
|
||||
template <typename MaskT, typename T1, typename T2, int N>
|
||||
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
|
||||
static_assert(std::is_same_v<MaskT, bool>);
|
||||
if constexpr (sizeof(T1) == 1) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 2) {
|
||||
@@ -251,9 +252,13 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
|
||||
return asd::pow(base.value, exp.value);
|
||||
} else {
|
||||
Simd<T, N> res = 1;
|
||||
while (any(exp)) {
|
||||
res = select(exp & 1, res * base, res);
|
||||
base = select(exp, base * base, base);
|
||||
// Raising an integer to a negative power is undefined
|
||||
if (any(exp < 0)) {
|
||||
return 0;
|
||||
}
|
||||
while (any(exp > 0)) {
|
||||
res = select((exp & 1) != 0, res * base, res);
|
||||
base = select(exp > 0, base * base, base);
|
||||
exp = exp >> 1;
|
||||
}
|
||||
return res;
|
||||
|
||||
Reference in New Issue
Block a user