Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)

* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Paul Paczuski
2024-07-24 11:38:22 -04:00
committed by GitHub
parent 50eff6a10a
commit ebd7135b50
3 changed files with 35 additions and 1 deletions

View File

@@ -124,6 +124,11 @@ def binary_cross_entropy(
"""
Computes the binary cross entropy loss.
By default, this function takes the pre-sigmoid logits, which results in a faster
and more precise loss. For improved numerical stability when ``with_logits=False``,
the loss calculation clips the input probabilities (in log-space) to a minimum value
of ``-100``.
Args:
inputs (array): The predicted values. If ``with_logits`` is ``True``, then
``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
@@ -159,7 +164,9 @@ def binary_cross_entropy(
if with_logits:
loss = mx.logaddexp(0.0, inputs) - inputs * targets
else:
loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
log_inputs_clip = mx.clip(mx.log(inputs), a_min=-100, a_max=None)
log_inputs_inv_clip = mx.clip(mx.log(1 - inputs), a_min=-100, a_max=None)
loss = -(targets * log_inputs_clip + (1 - targets) * log_inputs_inv_clip)
# Apply weights if provided
if weights is not None: