diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 598161f78..b3183f4f8 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -17,6 +17,7 @@ MLX was developed with contributions from the following individuals:
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
 - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
+- Paul Paczuski: Improved stability of BCE loss calculation
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 5bbeb1f06..55b5a68cc 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -124,6 +124,11 @@ def binary_cross_entropy(
     """
     Computes the binary cross entropy loss.
 
+    By default, this function takes the pre-sigmoid logits, which results in a faster
+    and more precise loss. For improved numerical stability when ``with_logits=False``,
+    the loss calculation clips the input probabilities (in log-space) to a minimum value
+    of ``-100``.
+
     Args:
         inputs (array): The predicted values. If ``with_logits`` is ``True``, then
             ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
@@ -159,7 +164,9 @@ def binary_cross_entropy(
     if with_logits:
         loss = mx.logaddexp(0.0, inputs) - inputs * targets
     else:
-        loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
+        log_inputs_clip = mx.clip(mx.log(inputs), a_min=-100, a_max=None)
+        log_inputs_inv_clip = mx.clip(mx.log(1 - inputs), a_min=-100, a_max=None)
+        loss = -(targets * log_inputs_clip + (1 - targets) * log_inputs_inv_clip)
 
     # Apply weights if provided
     if weights is not None:
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index 3a430be21..102ec857d 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -125,8 +125,34 @@ class TestLosses(mlx_tests.MLXTestCase):
             expected_sum = mx.sum(expected_none)
             self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
+        def _test_tiny_probs_as_inputs():
+            TINY_PROB = 1e-59
+            probs = mx.array([0, TINY_PROB, 1 - TINY_PROB, 1])
+            targets = mx.array([0, 0, 1, 1])
+
+            losses_none = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="none"
+            )
+            expected_none = mx.array([0.0, TINY_PROB, TINY_PROB, 0.0])
+            self.assertTrue(mx.allclose(losses_none, expected_none))
+
+            # Test with reduction 'mean'
+            losses_mean = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="mean"
+            )
+            expected_mean = mx.mean(expected_none)
+            self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+            # Test with reduction 'sum'
+            losses_sum = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="sum"
+            )
+            expected_sum = mx.sum(expected_none)
+            self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
         _test_logits_as_inputs()
         _test_probs_as_inputs()
+        _test_tiny_probs_as_inputs()
 
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
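
For reference, a minimal sketch (not part of the patch) of the failure mode this change addresses, assuming MLX with the patch applied; the probability values, variable names, and printed results are illustrative only:

import mlx.core as mx
import mlx.nn as nn

probs = mx.array([0.0, 1e-59, 0.5, 1.0])  # probabilities at or near the domain edges
targets = mx.array([0.0, 0.0, 1.0, 1.0])

# Unclipped form: mx.log(0.0) is -inf, and 0 * -inf produces nan at the endpoints.
naive = -(targets * mx.log(probs) + (1 - targets) * mx.log(1 - probs))

# Patched form: the log-probabilities are clipped to -100, so every term stays finite.
stable = nn.losses.binary_cross_entropy(
    probs, targets, with_logits=False, reduction="none"
)

print(naive)   # nan at indices 0 and 3
print(stable)  # finite, approximately [0, 0, 0.693, 0]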