Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
Update binary_cross_entropy function to handle both logits and probabilities (#492)
This commit is contained in:
parent f6e911ced0
commit 550d4bf7c0
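For readers comparing the two modes this commit introduces, a minimal sketch (not part of the commit) of how they relate: the logits path computes ``mx.logaddexp(0.0, x) - x * t``, which is the numerically stable form of ``-(t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x)))``, so calling the loss on logits should agree with calling it on ``mx.sigmoid(logits)`` with ``with_logits=False``. The use of ``mx.sigmoid`` and the variable names here are illustrative assumptions, not part of the diff below.

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
    targets = mx.array([0, 0, 1, 1])

    # Logits path: log(1 + exp(x)) - x * t, computed via logaddexp for stability.
    loss_from_logits = nn.losses.binary_cross_entropy(logits, targets, reduction="none")

    # Probabilities path on sigmoid(logits) should give (numerically) the same values.
    probs = mx.sigmoid(logits)
    loss_from_probs = nn.losses.binary_cross_entropy(
        probs, targets, with_logits=False, reduction="none"
    )

    assert mx.allclose(loss_from_logits, loss_from_probs)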
@@ -74,29 +74,50 @@ def cross_entropy(
 
 
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    with_logits: bool = True,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
 
     Args:
-        logits (array): The unnormalized (pre-sigmoid) predicted logits.
+        inputs (array): The predicted values. If ``with_logits`` is ``True``, then
+            ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
         targets (array): The binary target values in {0, 1}.
+        with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
         reduction (str, optional): Specifies the reduction to apply to the output:
-          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
 
     Returns:
         array: The computed binary cross entropy loss.
     Examples:
         >>> import mlx.core as mx
         >>> import mlx.nn as nn
-        >>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
+
+        >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
         >>> targets = mx.array([0, 0, 1, 1])
-        >>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
+        >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
         >>> loss
-        array([0.612192], dtype=float32)
+        array(0.539245, dtype=float32)
+
+        >>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
+        >>> targets = mx.array([0, 0, 1, 1])
+        >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
+        >>> loss
+        array(0.510826, dtype=float32)
     """
-    loss = mx.logaddexp(0.0, logits) - targets * logits
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+
+    if with_logits:
+        loss = mx.logaddexp(0.0, inputs) - inputs * targets
+    else:
+        loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
+
     return _reduce(loss, reduction)
 
 
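As a sanity check on the docstring example values above (a sketch, not part of the commit): with ``with_logits=True`` each element is ``log(1 + exp(x)) - x * t``, and with ``with_logits=False`` it is ``-(t * log(p) + (1 - t) * log(1 - p))``; averaging the four elements reproduces 0.539245 and 0.510826.

    import math

    # Logits example from the docstring.
    logits = [0.105361, 0.223144, 1.20397, 0.916291]
    targets = [0, 0, 1, 1]
    per_elem = [math.log1p(math.exp(x)) - x * t for x, t in zip(logits, targets)]
    print(sum(per_elem) / len(per_elem))  # ~0.539245

    # Probabilities example from the docstring.
    probs = [0.1, 0.1, 0.4, 0.4]
    per_elem = [-(t * math.log(p) + (1 - t) * math.log(1 - p)) for p, t in zip(probs, targets)]
    print(sum(per_elem) / len(per_elem))  # ~0.510826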
@@ -105,6 +105,61 @@ class TestLosses(mlx_tests.MLXTestCase):
             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
         )
 
+    def test_binary_cross_entropy(self):
+        def _test_logits_as_inputs():
+            logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
+            targets = mx.array([0, 0, 1, 1])
+
+            # Test with reduction 'none'
+            losses_none = nn.losses.binary_cross_entropy(
+                logits, targets, reduction="none"
+            )
+            expected_none = mx.array([0.747215, 0.810930, 0.262365, 0.336472])
+            self.assertTrue(mx.allclose(losses_none, expected_none))
+
+            # Test with reduction 'mean'
+            losses_mean = nn.losses.binary_cross_entropy(
+                logits, targets, reduction="mean"
+            )
+            expected_mean = mx.mean(expected_none)
+            self.assertEqual(losses_mean, expected_mean)
+
+            # Test with reduction 'sum'
+            losses_sum = nn.losses.binary_cross_entropy(
+                logits, targets, reduction="sum"
+            )
+            expected_sum = mx.sum(expected_none)
+            self.assertEqual(losses_sum, expected_sum)
+
+        def _test_probs_as_inputs():
+            probs = mx.array([0.5, 0.6, 0.7, 0.8])
+            targets = mx.array([0, 0, 1, 1])
+
+            # Test with reduction 'none'
+            losses_none = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="none"
+            )
+            expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
+            print(losses_none, expected_none)
+            self.assertTrue(mx.allclose(losses_none, expected_none))
+
+            # Test with reduction 'mean'
+            losses_mean = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="mean"
+            )
+            expected_mean = mx.mean(expected_none)
+            self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+            # Test with reduction 'sum'
+            losses_sum = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="sum"
+            )
+            expected_sum = mx.sum(expected_none)
+            self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+        _test_logits_as_inputs()
+        _test_probs_as_inputs()
+
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
         targets = mx.array([0.5, 0.2, 0.9, 0.0])
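One usage caveat about the probabilities path shown in the diff (an observation, not part of the commit): ``mx.log(inputs)`` and ``mx.log(1 - inputs)`` are undefined at exactly 0 or 1, so probability inputs should stay strictly inside (0, 1). A small clip before calling the loss is one way to guard against that; the epsilon value and variable names below are arbitrary choices for illustration.

    import mlx.core as mx
    import mlx.nn as nn

    eps = 1e-7  # arbitrary choice for illustration
    raw_probs = mx.array([0.0, 0.25, 0.75, 1.0])
    targets = mx.array([0, 0, 1, 1])

    # Clip away from exactly 0 and 1 so log() stays finite.
    probs = mx.clip(raw_probs, eps, 1 - eps)
    loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")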