mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation * Standardize comment * Apply formatting with black via pre-commit * Add usage recommendation to docstring * Update python/mlx/nn/losses.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
50eff6a10a
commit
ebd7135b50
@ -17,6 +17,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@ -124,6 +124,11 @@ def binary_cross_entropy(
|
||||
"""
|
||||
Computes the binary cross entropy loss.
|
||||
|
||||
By default, this function takes the pre-sigmoid logits, which results in a faster
|
||||
and more precise loss. For improved numerical stability when ``with_logits=False``,
|
||||
the loss calculation clips the input probabilities (in log-space) to a minimum value
|
||||
of ``-100``.
|
||||
|
||||
Args:
|
||||
inputs (array): The predicted values. If ``with_logits`` is ``True``, then
|
||||
``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
|
||||
@ -159,7 +164,9 @@ def binary_cross_entropy(
|
||||
if with_logits:
|
||||
loss = mx.logaddexp(0.0, inputs) - inputs * targets
|
||||
else:
|
||||
loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
|
||||
log_inputs_clip = mx.clip(mx.log(inputs), a_min=-100, a_max=None)
|
||||
log_inputs_inv_clip = mx.clip(mx.log(1 - inputs), a_min=-100, a_max=None)
|
||||
loss = -(targets * log_inputs_clip + (1 - targets) * log_inputs_inv_clip)
|
||||
|
||||
# Apply weights if provided
|
||||
if weights is not None:
|
||||
|
@ -125,8 +125,34 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def _test_tiny_probs_as_inputs():
|
||||
TINY_PROB = 1e-59
|
||||
probs = mx.array([0, TINY_PROB, 1 - TINY_PROB, 1])
|
||||
targets = mx.array([0, 0, 1, 1])
|
||||
|
||||
losses_none = nn.losses.binary_cross_entropy(
|
||||
probs, targets, with_logits=False, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([0.0, TINY_PROB, TINY_PROB, 0.0])
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.binary_cross_entropy(
|
||||
probs, targets, with_logits=False, reduction="mean"
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.binary_cross_entropy(
|
||||
probs, targets, with_logits=False, reduction="sum"
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
_test_logits_as_inputs()
|
||||
_test_probs_as_inputs()
|
||||
_test_tiny_probs_as_inputs()
|
||||
|
||||
def test_l1_loss(self):
|
||||
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
|
||||
|
Loading…
Reference in New Issue
Block a user