Update binary_cross_entropy function to handle both logits and probabilities (#492)

AtomicVar authored 2024-01-19 11:22:23 +08:00; committed by GitHub
parent f6e911ced0
commit 550d4bf7c0
2 changed files with 83 additions and 7 deletions

python/mlx/nn/losses.py

@@ -74,29 +74,50 @@ def cross_entropy(
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    with_logits: bool = True,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
 
     Args:
-        logits (array): The unnormalized (pre-sigmoid) predicted logits.
+        inputs (array): The predicted values. If ``with_logits`` is ``True``, then
+            ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
         targets (array): The binary target values in {0, 1}.
+        with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
         reduction (str, optional): Specifies the reduction to apply to the output:
-          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
 
     Returns:
         array: The computed binary cross entropy loss.
 
     Examples:
         >>> import mlx.core as mx
         >>> import mlx.nn as nn
-        >>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
+        >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
         >>> targets = mx.array([0, 0, 1, 1])
-        >>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
+        >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
         >>> loss
-        array([0.612192], dtype=float32)
+        array(0.539245, dtype=float32)
+
+        >>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
+        >>> targets = mx.array([0, 0, 1, 1])
+        >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
+        >>> loss
+        array(0.510826, dtype=float32)
     """
-    loss = mx.logaddexp(0.0, logits) - targets * logits
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+
+    if with_logits:
+        loss = mx.logaddexp(0.0, inputs) - inputs * targets
+    else:
+        loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
     return _reduce(loss, reduction)
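
For reference, the logits path relies on the identity -(t * log(p) + (1 - t) * log(1 - p)) = logaddexp(0, x) - t * x when p = sigmoid(x), so the loss is computed without ever materializing a probability that could saturate to 0 or 1. Below is a minimal sketch of the updated API (assuming an MLX build that includes this commit), checking that the two modes agree when the probabilities are the sigmoid of the logits; the variable names are illustrative only:

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
    targets = mx.array([0, 0, 1, 1])

    # Default path: inputs are raw logits (with_logits=True), evaluated via
    # the stable form logaddexp(0, x) - x * t, with no explicit sigmoid.
    loss_from_logits = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")

    # Probability path: pass sigmoid outputs and set with_logits=False; this
    # evaluates -(t * log(p) + (1 - t) * log(1 - p)) directly.
    probs = mx.sigmoid(logits)
    loss_from_probs = nn.losses.binary_cross_entropy(
        probs, targets, with_logits=False, reduction="mean"
    )

    print(loss_from_logits.item(), loss_from_probs.item())  # both ~0.539245

Keeping with_logits=True is the numerically safer default; the probability path exists for pipelines whose models already emit sigmoid outputs.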