mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Fix log1p with inf inputs (#592)
This commit is contained in:
parent
09b9275027
commit
1895d34c20
@ -235,12 +235,24 @@ inline size_t ceildiv(size_t N, size_t M) {
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<float>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
|
||||
|
||||
inline bfloat16_t log1p(bfloat16_t x) {
|
||||
float xp1 = 1.0f + static_cast<float>(x);
|
||||
bfloat16_t ret =
|
||||
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
return ret;
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<bfloat16_t>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user