Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-19 16:48:10 +08:00)
Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation
* Standardize comment
* Apply formatting with black via pre-commit
* Add usage recommendation to docstring
* Update python/mlx/nn/losses.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
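For context: a naive probability-input BCE computes -(t * log(p) + (1 - t) * log(1 - p)), which produces inf or nan as soon as p is exactly 0 or 1. The commit message does not reproduce the fix itself, so here is a minimal sketch of one standard stabilization, assuming the change clamps the log terms to a finite lower bound; the helper name and the -100 bound are illustrative (the bound is borrowed from PyTorch's BCELoss convention), not taken from the patch:

    import mlx.core as mx

    def stable_bce_from_probs(probs, targets):
        # Hypothetical helper (not from the patch): clamp the log terms to a
        # finite lower bound so that log(0) = -inf cannot propagate into the
        # loss or its gradient.
        log_p = mx.clip(mx.log(probs), a_min=-100.0, a_max=None)
        log_1mp = mx.clip(mx.log(1 - probs), a_min=-100.0, a_max=None)
        return -(targets * log_p + (1 - targets) * log_1mp)

The "usage recommendation" added to the docstring presumably amounts to preferring with_logits=True when logits are available, since the logit path can be evaluated in a log-sum-exp form and never computes log(0) at all.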
@@ -125,8 +125,34 @@ class TestLosses(mlx_tests.MLXTestCase):
             expected_sum = mx.sum(expected_none)
             self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
+        def _test_tiny_probs_as_inputs():
+            TINY_PROB = 1e-59
+            probs = mx.array([0, TINY_PROB, 1 - TINY_PROB, 1])
+            targets = mx.array([0, 0, 1, 1])
+
+            losses_none = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="none"
+            )
+            expected_none = mx.array([0.0, TINY_PROB, TINY_PROB, 0.0])
+            self.assertTrue(mx.allclose(losses_none, expected_none))
+
+            # Test with reduction 'mean'
+            losses_mean = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="mean"
+            )
+            expected_mean = mx.mean(expected_none)
+            self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+            # Test with reduction 'sum'
+            losses_sum = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="sum"
+            )
+            expected_sum = mx.sum(expected_none)
+            self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
         _test_logits_as_inputs()
         _test_probs_as_inputs()
+        _test_tiny_probs_as_inputs()
 
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
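Why those expected values: for target t and probability p, BCE is -(t * log(p) + (1 - t) * log(1 - p)). With t = 0 and p = TINY_PROB = 1e-59, the exact loss is -log(1 - 1e-59) ≈ 1e-59 by the expansion log(1 - x) ≈ -x for small x; the t = 1, p = 1 - 1e-59 case is symmetric. At the endpoints (t = 0, p = 0 and t = 1, p = 1) the true loss is exactly 0, but a naive implementation evaluates 0 * log(0) = 0 * (-inf), which is nan under IEEE-754. Worth noting: 1e-59 lies below float32's smallest positive subnormal (about 1.4e-45), so on MLX's default float32 arrays TINY_PROB rounds to 0 and the test effectively exercises the exact-endpoint cases the commit title mentions.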