Fix log1p with inf inputs (#592)

This commit is contained in:
Angelos Katharopoulos 2024-01-30 14:02:50 -08:00 committed by GitHub
parent 09b9275027
commit 1895d34c20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -235,12 +235,24 @@ inline size_t ceildiv(size_t N, size_t M) {
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
float xp1 = 1.0f + x;
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
if (xp1 == Limits<float>::max) {
return Limits<float>::max;
}
if (xp1 == 1.0f) {
return x;
}
return x * (metal::log(xp1) / (xp1 - 1.0f));
}
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
bfloat16_t ret =
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
return ret;
if (xp1 == Limits<float>::max) {
return Limits<bfloat16_t>::max;
}
if (xp1 == 1.0f) {
return x;
}
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}