mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 10:41:14 +08:00
Fix log1p with inf inputs (#592)
This commit is contained in:
parent
09b9275027
commit
1895d34c20
@ -235,12 +235,24 @@ inline size_t ceildiv(size_t N, size_t M) {
|
|||||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||||
inline float log1p(float x) {
|
inline float log1p(float x) {
|
||||||
float xp1 = 1.0f + x;
|
float xp1 = 1.0f + x;
|
||||||
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
if (xp1 == Limits<float>::max) {
|
||||||
|
return Limits<float>::max;
|
||||||
|
}
|
||||||
|
if (xp1 == 1.0f) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bfloat16_t log1p(bfloat16_t x) {
|
inline bfloat16_t log1p(bfloat16_t x) {
|
||||||
float xp1 = 1.0f + static_cast<float>(x);
|
float xp1 = 1.0f + static_cast<float>(x);
|
||||||
bfloat16_t ret =
|
if (xp1 == Limits<float>::max) {
|
||||||
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
return Limits<bfloat16_t>::max;
|
||||||
return ret;
|
}
|
||||||
|
if (xp1 == 1.0f) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user