Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)

* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Paul Paczuski
2024-07-24 11:38:22 -04:00
committed by GitHub
parent 50eff6a10a
commit ebd7135b50
3 changed files with 35 additions and 1 deletions

View File

@@ -125,8 +125,34 @@ class TestLosses(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def _test_tiny_probs_as_inputs():
TINY_PROB = 1e-59
probs = mx.array([0, TINY_PROB, 1 - TINY_PROB, 1])
targets = mx.array([0, 0, 1, 1])
losses_none = nn.losses.binary_cross_entropy(
probs, targets, with_logits=False, reduction="none"
)
expected_none = mx.array([0.0, TINY_PROB, TINY_PROB, 0.0])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.binary_cross_entropy(
probs, targets, with_logits=False, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.binary_cross_entropy(
probs, targets, with_logits=False, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
_test_logits_as_inputs()
_test_probs_as_inputs()
_test_tiny_probs_as_inputs()
def test_l1_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])