This commit is contained in:
Awni Hannun 2025-08-21 06:46:01 -07:00 committed by GitHub
parent 0c5fc63a36
commit e843c4d8d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 3 deletions

View File

@ -234,6 +234,7 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N> template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) { Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) { if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value)); return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) { } else if constexpr (sizeof(T1) == 2) {
@ -251,9 +252,13 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value); return asd::pow(base.value, exp.value);
} else { } else {
Simd<T, N> res = 1; Simd<T, N> res = 1;
while (any(exp)) { // Raising an integer to a negative power is undefined
res = select(exp & 1, res * base, res); if (any(exp < 0)) {
base = select(exp, base * base, base); return 0;
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1; exp = exp >> 1;
} }
return res; return res;

View File

@ -204,6 +204,10 @@ struct Power {
__device__ T operator()(T base, T exp) { __device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
T res = 1; T res = 1;
// Raising an integer to a negative power is undefined
if (exp < 0) {
return 0;
}
while (exp) { while (exp) {
if (exp & 1) { if (exp & 1) {
res *= base; res *= base;

View File

@ -223,6 +223,11 @@ struct Power {
template <typename T> template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) { metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1; T res = 1;
// Undefined to raise integer to negative power
if (exp < 0) {
return 0;
}
while (exp) { while (exp) {
if (exp & 1) { if (exp & 1) {
res *= base; res *= base;

View File

@ -3068,6 +3068,13 @@ class TestOps(mlx_tests.MLXTestCase):
d = mx.where(c, a[1:], b) d = mx.where(c, a[1:], b)
self.assertTrue(mx.all(d == 1.0)) self.assertTrue(mx.all(d == 1.0))
def test_integer_power(self):
x = mx.power(2, mx.array([8, 8, 8, 8, 8, 8, 8, 8]))
self.assertTrue(mx.all(x == 256))
# Doesn't hang
x = mx.power(2, -1)
class TestBroadcast(mlx_tests.MLXTestCase): class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self): def test_broadcast_shapes(self):