diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 72cdd8b20..ce6dc2954 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -235,12 +235,24 @@ inline size_t ceildiv(size_t N, size_t M) { // https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 inline float log1p(float x) { float xp1 = 1.0f + x; - return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f)); + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return x * (metal::log(xp1) / (xp1 - 1.0f)); } inline bfloat16_t log1p(bfloat16_t x) { float xp1 = 1.0f + static_cast(x); - bfloat16_t ret = - (xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); - return ret; + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); }