mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation * Standardize comment * Apply formatting with black via pre-commit * Add usage recommendation to docstring * Update python/mlx/nn/losses.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -124,6 +124,11 @@ def binary_cross_entropy(
|
||||
"""
|
||||
Computes the binary cross entropy loss.
|
||||
|
||||
By default, this function takes the pre-sigmoid logits, which results in a faster
|
||||
and more precise loss. For improved numerical stability when ``with_logits=False``,
|
||||
the loss calculation clips the input probabilities (in log-space) to a minimum value
|
||||
of ``-100``.
|
||||
|
||||
Args:
|
||||
inputs (array): The predicted values. If ``with_logits`` is ``True``, then
|
||||
``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
|
||||
@@ -159,7 +164,9 @@ def binary_cross_entropy(
|
||||
if with_logits:
|
||||
loss = mx.logaddexp(0.0, inputs) - inputs * targets
|
||||
else:
|
||||
loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
|
||||
log_inputs_clip = mx.clip(mx.log(inputs), a_min=-100, a_max=None)
|
||||
log_inputs_inv_clip = mx.clip(mx.log(1 - inputs), a_min=-100, a_max=None)
|
||||
loss = -(targets * log_inputs_clip + (1 - targets) * log_inputs_inv_clip)
|
||||
|
||||
# Apply weights if provided
|
||||
if weights is not None:
|
||||
|
Reference in New Issue
Block a user