Add smoothed L1 loss and enhancements to cross entropy loss (#166)

* Add smooth_l1_loss
* Add labels moothing for cross entropy loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
jojopuppet 2023-12-18 23:26:21 +08:00 committed by GitHub
parent 0e5807bbcb
commit 18cca64c81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 278 additions and 172 deletions

View File

@ -9,9 +9,10 @@ Loss Functions
:toctree: _autosummary_functions
:template: nn-module-template.rst
cross_entropy
binary_cross_entropy
cross_entropy
kl_div_loss
l1_loss
mse_loss
nll_loss
kl_div_loss
smooth_l1_loss

View File

@ -4,145 +4,138 @@ import mlx.core as mx
from mlx.nn.layers.base import Module
def _make_loss_module(f):
def decorator(klass):
klass.__call__ = lambda self, inputs, targets: f(
inputs, targets, self.reduction
)
return klass
return decorator
def cross_entropy(
logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
logits: mx.array,
targets: mx.array,
weights: mx.array = None,
axis: int = -1,
label_smoothing: float = 0.0,
reduction: str = "none",
) -> mx.array:
"""
Computes the cross entropy loss between logits and targets.
Computes the cross entropy loss.
Args:
logits (mx.array): The predicted logits.
targets (mx.array): The target values.
logits (array): The unnormalized predicted logits.
targets (array): The target values, as class indices.
weights (array, optional): Weights for each target. Default: ``None``.
axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
mx.array: The computed cross entropy loss.
array: The computed cross entropy loss.
"""
score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
loss = mx.logsumexp(logits, axis=axis) - score
if label_smoothing < 0 or label_smoothing >= 1:
raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")
score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
logsumexp_logits = mx.logsumexp(logits, axis=axis)
if label_smoothing > 0:
# Adjust the true class score with label smoothing
adjusted_score = (1 - label_smoothing) * score
# Calculate the mean logit across the classes for smoothed loss
mean_logits = logits.mean(axis=axis)
smoothed_loss = -mean_logits * label_smoothing
# Combine the adjusted score and smoothed loss with the logsumexp logits
loss = logsumexp_logits - adjusted_score + smoothed_loss
else:
loss = logsumexp_logits - score
# Apply weights if provided
if weights is not None:
if weights.shape != targets.shape:
raise ValueError(
f"Weights with shape {weights.shape} is not the same as "
f"targets with shape {targets.shape}."
)
loss *= weights
# Apply reduction
return _reduce(loss, reduction)
def binary_cross_entropy(
inputs: mx.array, targets: mx.array, reduction: str = "none"
logits: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
Computes the binary cross entropy loss between inputs and targets.
Computes the binary cross entropy loss.
Args:
inputs (mx.array): The predicted inputs (post-sigmoid probabilities).
targets (mx.array): The target values (binary labels).
logits (array): The unnormalized (pre-sigmoid) predicted logits.
targets (array): The binary target values in {0, 1}.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
mx.array: The computed binary cross entropy loss.
array: The computed binary cross entropy loss.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> inputs = mx.array([0.1, 0.2, 0.3, 0.4])
>>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
>>> targets = mx.array([0, 0, 1, 1])
>>> loss = nn.losses.binary_cross_entropy(inputs, targets)
>>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
>>> loss
array([0.612192])
array([0.612192], dtype=float32)
"""
loss = -targets * mx.log(inputs) - (1 - targets) * mx.log(1 - inputs)
loss = mx.logaddexp(0.0, logits) - targets * logits
return _reduce(loss, reduction)
@_make_loss_module(binary_cross_entropy)
class BCELoss(Module):
"""
Binary Cross Entropy Loss module.
It computes the binary cross entropy loss between predicted probabilities (post-sigmoid inputs) and target binary labels.
Args:
reduction (str, optional): Specifies the reduction to apply to the output:
- 'none': no reduction (default)
- 'mean': compute the mean loss
- 'sum': compute the sum of the loss
Examples:
>>> import mlx.core as mx
>>> from mlx.nn.losses import BCELoss
>>>
>>> # Create BCELoss module with default reduction ('none')
>>> loss_module_none = BCELoss()
>>> inputs = mx.array([0.5, 0.7, 0.3])
>>> targets = mx.array([1, 0, 1])
>>> loss_none = loss_module_none(inputs, targets)
>>> print(loss_none)
array([0.693147, 1.20397, 1.20397], dtype=float32)
>>> # Create BCELoss module with reduction 'mean'
>>> loss_module_mean = BCELoss(reduction='mean')
>>> loss_mean = loss_module_mean(inputs, targets)
>>> print(loss_mean)
array(1.0337, dtype=float32)
>>> # Create BCELoss module with reduction 'sum'
>>> loss_module_sum = BCELoss(reduction='sum')
>>> loss_sum = loss_module_sum(inputs, targets)
>>> print(loss_sum)
array(3.10109, dtype=float32)
"""
def __init__(self, reduction: str = "none"):
super().__init__()
self.reduction = reduction
def l1_loss(
predictions: mx.array, targets: mx.array, reduction: str = "none"
predictions: mx.array, targets: mx.array, reduction: str = "mean"
) -> mx.array:
"""
Computes the L1 loss between predictions and targets.
Computes the L1 loss.
Args:
predictions (mx.array): The predicted values.
targets (mx.array): The target values.
predictions (array): The predicted values.
targets (array): The target values.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
Returns:
mx.array: The computed L1 loss.
array: The computed L1 loss.
"""
loss = mx.mean(mx.abs(predictions - targets))
if predictions.shape != targets.shape:
raise ValueError(
f"Predictions shape {predictions.shape} does not match "
f"targets shape {targets.shape}."
)
loss = mx.abs(predictions - targets)
return _reduce(loss, reduction)
def mse_loss(
predictions: mx.array, targets: mx.array, reduction: str = "none"
predictions: mx.array, targets: mx.array, reduction: str = "mean"
) -> mx.array:
"""
Computes the mean squared error loss between predictions and targets.
Computes the mean squared error loss.
Args:
predictions (mx.array): The predicted values.
targets (mx.array): The target values.
predictions (array): The predicted values.
targets (array): The target values.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
Returns:
mx.array: The computed mean squared error loss.
array: The computed mean squared error loss.
"""
loss = mx.square(predictions - targets)
if predictions.shape != targets.shape:
raise ValueError(
f"Predictions shape {predictions.shape} does not match "
f"targets shape {targets.shape}."
)
assert (
predictions.shape == targets.shape
), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match"
loss = mx.square(predictions - targets)
return _reduce(loss, reduction)
@ -150,17 +143,17 @@ def nll_loss(
inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
) -> mx.array:
"""
Computes the negative log likelihood loss between inputs and targets.
Computes the negative log likelihood loss.
Args:
inputs (mx.array): The predicted distribution in log space.
targets (mx.array): The target values.
inputs (array): The predicted distribution in log space.
targets (array): The target values.
axis (int, optional): The distribution axis. Default: ``-1``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
mx.array: The computed NLL loss.
array: The computed NLL loss.
"""
loss = -mx.take_along_axis(inputs, targets[..., None], axis).squeeze(-1)
@ -171,8 +164,7 @@ def kl_div_loss(
inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
) -> mx.array:
"""
Computes the Kullback-Leibler divergence loss between targets and the
inputs.
Computes the Kullback-Leibler divergence loss.
Computes the following when ``reduction == 'none'``:
@ -181,20 +173,65 @@ def kl_div_loss(
mx.exp(targets) * (targets - inputs).sum(axis)
Args:
inputs (mx.array): Log probabilities for the predicted distribution.
targets (mx.array): Log probabilities for the target distribution.
inputs (array): Log probabilities for the predicted distribution.
targets (array): Log probabilities for the target distribution.
axis (int, optional): The distribution axis. Default: ``-1``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
mx.array: The computed Kullback-Leibler divergence loss.
array: The computed Kullback-Leibler divergence loss.
"""
loss = mx.sum(mx.exp(targets) * (targets - inputs), axis)
return _reduce(loss, reduction)
def smooth_l1_loss(
predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
) -> mx.array:
r"""
Computes the smooth L1 loss.
The smooth L1 loss is a variant of the L1 loss which replaces the absolute
difference with a squared difference when the absolute difference is less
than ``beta``.
The formula for the smooth L1 Loss is:
.. math::
l =
\begin{cases}
0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
|x - y| - 0.5 \beta, & & \text{otherwise}
\end{cases}
Args:
predictions (array): Predicted values.
targets (array): Ground truth values.
beta (float, optional): The threshold after which the loss changes
from the squared to the absolute difference. Default: ``1.0``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
Returns:
array: The computed smooth L1 loss.
"""
if predictions.shape != targets.shape:
raise ValueError(
f"Predictions shape {predictions.shape} does not match "
f"targets shape {targets.shape}."
)
diff = predictions - targets
loss = mx.where(
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
)
return _reduce(loss, reduction)
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)

View File

@ -37,30 +37,169 @@ class TestNN(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
# Test cases with weights and no label smoothing
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])
weights = mx.array([1.0, 2.0])
# Reduction 'none'
losses_none = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="none",
)
expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses
self.assertTrue(
np.allclose(losses_none, expected_none, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
)
# Reduction 'mean'
losses_mean = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="mean",
)
expected_mean = mx.mean(expected_none)
self.assertTrue(
np.allclose(losses_mean, expected_mean, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
)
# Reduction 'sum'
losses_sum = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="sum",
)
expected_sum = mx.sum(expected_none)
self.assertTrue(
np.allclose(losses_sum, expected_sum, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
)
# Test case with equal weights and label smoothing > 0
logits = mx.array(
[[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
)
target = mx.array([2, 1, 0])
losses_none = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="none"
)
expected_none = mx.array([1.29693, 1.38617, 1.48176])
self.assertTrue(
mx.allclose(expected_none, losses_none),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
)
expected_mean = mx.mean(expected_none)
losses_mean = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="mean"
)
self.assertTrue(
mx.allclose(losses_mean, expected_mean),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
)
expected_sum = mx.sum(expected_none)
losses_sum = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="sum"
)
self.assertTrue(
mx.allclose(losses_sum, expected_sum),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
)
def test_l1_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.5, 0.2, 0.9, 0.0])
# Expected result
expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
expected_sum = mx.sum(expected_none)
expected_mean = mx.mean(expected_none)
losses = nn.losses.l1_loss(predictions, targets, reduction="none")
self.assertEqual(losses, 0.0)
self.assertTrue(
mx.array_equal(losses, expected_none),
"Test failed for l1_loss --reduction='none'",
)
losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
self.assertTrue(mx.array_equal(losses, expected_sum))
losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
self.assertTrue(mx.array_equal(losses, expected_mean))
def test_mse_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.7, 0.1, 0.8, 0.2])
expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
expected_mean = mx.mean(expected_none)
expected_sum = mx.sum(expected_none)
# Test with reduction 'none'
losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
self.assertTrue(mx.allclose(losses_none, expected_none))
self.assertTrue(
np.allclose(losses_none, expected_none, 1e-5),
"Test case failed for mse_loss --reduction='none'",
)
# Test with reduction 'mean'
losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
self.assertEqual(
losses_mean,
expected_mean,
"Test case failed for mse_loss --reduction='mean'",
)
# Test with reduction 'sum'
losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
self.assertEqual(
losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
)
def test_smooth_l1_loss(self):
predictions = mx.array([1.5, 2.5, 0.5, 3.5])
targets = mx.array([1.0, 2.0, 0.5, 2.5])
beta = 1.0
# Expected results
expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
expected_mean = mx.mean(expected_none)
# Test with reduction 'none'
loss_none = nn.losses.smooth_l1_loss(
predictions, targets, beta, reduction="none"
)
self.assertTrue(
mx.array_equal(loss_none, expected_none),
"Test case failed for smooth_l1_loss --reduction='none'",
)
# Test with reduction 'sum'
loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
self.assertEqual(
loss_sum,
expected_sum,
"Test case failed for smooth_l1_loss --reduction='sum'",
)
# Test with reduction 'mean'
loss_mean = nn.losses.smooth_l1_loss(
predictions, targets, beta, reduction="mean"
)
self.assertEqual(
loss_mean,
expected_mean,
"Test case failed for smooth_l1_loss --reduction='mean'",
)
def test_nll_loss(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
@ -100,77 +239,6 @@ class TestNN(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_binary_cross_entropy(self):
inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
targets = mx.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]])
# Test with reduction 'none'
losses_none = nn.losses.binary_cross_entropy(inputs, targets, reduction="none")
expected_none = mx.array(
[
[
0.6931471824645996,
0.6931471824645996,
0.2231435477733612,
0.10536054521799088,
],
[
2.3025851249694824,
0.3566749691963196,
0.6931471824645996,
0.6931471824645996,
],
]
)
self.assertTrue(mx.allclose(losses_none, expected_none, rtol=1e-5, atol=1e-8))
# Test with reduction 'mean'
losses_mean = nn.losses.binary_cross_entropy(inputs, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.binary_cross_entropy(inputs, targets, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_bce_loss_module(self):
inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
targets = mx.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]])
# Test with reduction 'none'
loss_module_none = nn.losses.BCELoss(reduction="none")
losses_none = loss_module_none(inputs, targets)
expected_none = mx.array(
[
[
0.6931471824645996,
0.6931471824645996,
0.2231435477733612,
0.10536054521799088,
],
[
2.3025851249694824,
0.3566749691963196,
0.6931471824645996,
0.6931471824645996,
],
]
)
self.assertTrue(mx.allclose(losses_none, expected_none, rtol=1e-5, atol=1e-8))
# Test with reduction 'mean'
loss_module_mean = nn.losses.BCELoss(reduction="mean")
losses_mean = loss_module_mean(inputs, targets)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
loss_module_sum = nn.losses.BCELoss(reduction="sum")
losses_sum = loss_module_sum(inputs, targets)
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_gelu(self):
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]