From e843c4d8d59013167a6bc4543ae611ac7c7e5f51 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 21 Aug 2025 06:46:01 -0700 Subject: [PATCH] fix power (#2523) --- mlx/backend/cpu/simd/accelerate_simd.h | 11 ++++++++--- mlx/backend/cuda/device/binary_ops.cuh | 4 ++++ mlx/backend/metal/kernels/binary_ops.h | 5 +++++ python/tests/test_ops.py | 7 +++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index 37b3cdbd8..ed7d11482 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -234,6 +234,7 @@ Simd remainder(Simd a, Simd b) { template Simd select(Simd mask, Simd x, Simd y) { + static_assert(std::is_same_v); if constexpr (sizeof(T1) == 1) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else if constexpr (sizeof(T1) == 2) { @@ -251,9 +252,13 @@ Simd pow(Simd base, Simd exp) { return asd::pow(base.value, exp.value); } else { Simd res = 1; - while (any(exp)) { - res = select(exp & 1, res * base, res); - base = select(exp, base * base, base); + // Raising an integer to a negative power is undefined + if (any(exp < 0)) { + return 0; + } + while (any(exp > 0)) { + res = select((exp & 1) != 0, res * base, res); + base = select(exp > 0, base * base, base); exp = exp >> 1; } return res; diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index 575aced14..31daf34cb 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -204,6 +204,10 @@ struct Power { __device__ T operator()(T base, T exp) { if constexpr (cuda::std::is_integral_v) { T res = 1; + // Raising an integer to a negative power is undefined + if (exp < 0) { + return 0; + } while (exp) { if (exp & 1) { res *= base; diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index f4deb860e..cb3e8a370 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -223,6 +223,11 @@ struct Power { template metal::enable_if_t, T> operator()(T base, T exp) { T res = 1; + // Undefined to raise integer to negative power + if (exp < 0) { + return 0; + } + while (exp) { if (exp & 1) { res *= base; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5bc51f297..bde57dee2 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3068,6 +3068,13 @@ class TestOps(mlx_tests.MLXTestCase): d = mx.where(c, a[1:], b) self.assertTrue(mx.all(d == 1.0)) + def test_integer_power(self): + x = mx.power(2, mx.array([8, 8, 8, 8, 8, 8, 8, 8])) + self.assertTrue(mx.all(x == 256)) + + # Doesn't hang + x = mx.power(2, -1) + class TestBroadcast(mlx_tests.MLXTestCase): def test_broadcast_shapes(self):