mirror of https://github.com/ml-explore/mlx.git
feat: Add new loss functions for neural networks
This commit is contained in:
parent a020a2d49d
commit 15307c5367
@@ -3,7 +3,6 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


def cross_entropy(
@@ -372,3 +371,174 @@ def log_cosh_loss(
    loss = mx.logaddexp(errors, -errors) - math.log(2)

    return _reduce(loss, reduction)


def focal_loss(
    inputs: mx.array,
    targets: mx.array,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
) -> mx.array:
    r"""
    Computes the focal loss between inputs and targets, which addresses class
    imbalance by down-weighting well-classified examples and focusing on the
    hard-to-classify ones.

    .. math::

        FL(p_t) = -\alpha (1 - p_t)^{\gamma} \log(p_t)

    Args:
        inputs (array): The predicted logits (unnormalized scores).
        targets (array): The ground truth target values.
        alpha (float, optional): The balancing parameter. Default: ``0.25``.
        gamma (float, optional): The focusing parameter. Default: ``2.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed focal loss.
    """
    if gamma < 0:
        raise ValueError(f"Focal loss gamma must be non-negative, got {gamma}.")

    # Binary cross-entropy with logits, i.e. -log(p_t)
    ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs

    # Probability assigned to the true class
    pt = mx.exp(-ce_loss)

    # Focal loss: alpha * (1 - p_t)^gamma * (-log(p_t)); ce_loss already
    # carries the negation, so no extra minus sign is needed here
    focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss

    return _reduce(focal_loss, reduction)
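
As a quick usage sketch (illustrative values only; mlx.core and mlx.nn are assumed to be imported as mx and nn, exactly as in the test file further down):

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0], [-0.5, 1.5]])
labels = mx.array([[1.0, 0.0], [0.0, 1.0]])
# per-element focal loss, reduced to a scalar mean
loss = nn.losses.focal_loss(logits, labels, alpha=0.25, gamma=2.0, reduction="mean")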


def dice_loss(
    inputs: mx.array,
    targets: mx.array,
    epsilon: float = 1e-6,
    reduction: str = "none",
) -> mx.array:
    r"""
    Computes the Dice loss, a measure of overlap between two samples.
    This loss is commonly used for binary segmentation tasks.

    .. math::

        \text{Dice Loss} = 1 - \frac{2 |X \cap Y|}{|X| + |Y|}

    Args:
        inputs (array): The predicted values.
        targets (array): The ground truth values.
        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed Dice loss.
    """
    intersection = mx.sum(inputs * targets, axis=1)
    cardinality = mx.sum(inputs + targets, axis=1)
    dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
    loss = 1 - dice_score

    return _reduce(loss, reduction)
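
To make the formula concrete, here is a small worked check on the same binary masks used in the unit test below (epsilon is negligible at this scale):

inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
# row 1: intersection = 2, cardinality = 6 -> dice = 4/6, loss ~ 0.333
# row 2: intersection = 1, cardinality = 4 -> dice = 2/4, loss = 0.5
loss = nn.losses.dice_loss(inputs, targets, reduction="mean")  # ~ 0.4167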


def iou_loss(
    inputs: mx.array,
    targets: mx.array,
    epsilon: float = 1e-6,
    reduction: str = "none",
) -> mx.array:
    r"""
    Computes the Intersection over Union (IoU) loss, a measure of the overlap
    between two sets, typically used in segmentation tasks.

    .. math::

        \text{IoU Loss} = 1 - \frac{|X \cap Y|}{|X \cup Y|}

    Args:
        inputs (array): The predicted values.
        targets (array): The ground truth values.
        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed IoU loss.
    """
    intersection = mx.sum(inputs * targets, axis=1)
    union = mx.sum(inputs + targets - inputs * targets, axis=1)
    iou_score = (intersection + epsilon) / (union + epsilon)
    loss = 1 - iou_score

    return _reduce(loss, reduction)
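
For the same example masks, the IoU form penalizes imperfect overlap a bit more strongly than the Dice loss (again ignoring epsilon):

inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
# row 1: intersection = 2, union = 4 -> IoU = 0.5, loss = 0.5
# row 2: intersection = 1, union = 3 -> IoU = 1/3, loss ~ 0.667
loss = nn.losses.iou_loss(inputs, targets, reduction="mean")  # ~ 0.5833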


def contrastive_loss(
    anchors: mx.array,
    positives: mx.array,
    negatives: mx.array,
    margin: float = 1.0,
    p: int = 2,
    reduction: str = "none",
) -> mx.array:
    r"""
    Computes a contrastive (triplet-margin style) loss for a set of anchor,
    positive, and negative samples.

    .. math::

        L_{\text{contrastive}} = \max\left(\|A - P\|_p - \|A - N\|_p + \text{margin}, 0\right)

    Args:
        anchors (array): The anchor samples.
        positives (array): The positive samples.
        negatives (array): The negative samples.
        margin (float, optional): Margin for the contrastive loss. Default: ``1.0``.
        p (int, optional): The norm degree for pairwise distance. Default: ``2``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed contrastive loss.
    """
    # p-norm distances: (sum_i |a_i - b_i|^p)^(1/p)
    positive_distance = mx.power(
        mx.sum(mx.abs(anchors - positives) ** p, axis=1), 1 / p
    )
    negative_distance = mx.power(
        mx.sum(mx.abs(anchors - negatives) ** p, axis=1), 1 / p
    )
    loss = mx.maximum(positive_distance - negative_distance + margin, 0)

    return _reduce(loss, reduction)
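
A brief usage sketch with the values from the unit test below: the negatives sit much farther from the anchors than the positives, so the hinge is inactive and the loss comes out to zero:

anchors = mx.array([[1, 2], [3, 4]])
positives = mx.array([[1, 3], [2, 4]])
negatives = mx.array([[5, 6], [7, 8]])
# d(A, P) = 1 and d(A, N) ~ 5.66 for both rows, so max(1 - 5.66 + 1, 0) = 0
loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=1.0, reduction="mean")  # 0.0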


def tversky_loss(
    inputs: mx.array,
    targets: mx.array,
    alpha: float = 0.5,
    beta: float = 0.5,
    epsilon: float = 1e-6,
    reduction: str = "none",
) -> mx.array:
    r"""
    Computes the Tversky loss, a generalization of the Dice loss that gives
    separate control over false positives and false negatives. It is
    particularly useful for segmentation tasks with imbalanced datasets.

    .. math::

        \text{Tversky Loss} = 1 - \frac{|X \cap Y|}{|X \cap Y| + \alpha |X \setminus Y| + \beta |Y \setminus X|}

    Args:
        inputs (array): The predicted values.
        targets (array): The ground truth values.
        alpha (float, optional): Controls the penalty for false positives. Default: ``0.5``.
        beta (float, optional): Controls the penalty for false negatives. Default: ``0.5``.
        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed Tversky loss.
    """
    # True positives, false positives (predicted but absent from the target),
    # and false negatives (present in the target but not predicted)
    intersection = mx.sum(inputs * targets, axis=1)
    false_positives = mx.sum(inputs * (1 - targets), axis=1)
    false_negatives = mx.sum((1 - inputs) * targets, axis=1)
    tversky_index = (intersection + epsilon) / (
        intersection + alpha * false_positives + beta * false_negatives + epsilon
    )
    loss = 1 - tversky_index

    return _reduce(loss, reduction)
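
A minimal sketch of how the weights relate to the Dice loss: with alpha = beta = 0.5 the Tversky index reduces to the Dice score, so the example masks above give the same mean loss as dice_loss; raising alpha shifts the penalty toward false positives:

inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
balanced = nn.losses.tversky_loss(inputs, targets, alpha=0.5, beta=0.5, reduction="mean")  # ~ 0.4167
fp_heavy = nn.losses.tversky_loss(inputs, targets, alpha=0.7, beta=0.3, reduction="mean")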

@@ -792,6 +792,70 @@ class TestNN(mlx_tests.MLXTestCase):
        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
        self.assertAlmostEqual(loss.item(), 0.433781, places=6)

    def test_focal_loss(self):
        inputs = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]])
        targets = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
        alpha = 0.25
        gamma = 2.0
        # Expected value mirrors the implementation: alpha * (1 - p_t)^gamma * (-log(p_t))
        ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
        pt = mx.exp(-ce_loss)
        expected_loss = alpha * ((1 - pt) ** gamma) * ce_loss
        expected_loss = mx.mean(expected_loss)
        loss = nn.losses.focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="mean")
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)

    def test_dice_loss(self):
        inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
        targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
        epsilon = 1e-6
        intersection = mx.sum(inputs * targets, axis=1)
        cardinality = mx.sum(inputs + targets, axis=1)
        dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
        expected_loss = 1 - dice_score
        expected_loss = mx.mean(expected_loss)
        loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)

    def test_iou_loss(self):
        inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
        targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
        epsilon = 1e-6
        intersection = mx.sum(inputs * targets, axis=1)
        union = mx.sum(inputs + targets - inputs * targets, axis=1)
        iou_score = (intersection + epsilon) / (union + epsilon)
        expected_loss = 1 - iou_score
        expected_loss = mx.mean(expected_loss)
        loss = nn.losses.iou_loss(inputs, targets, epsilon=epsilon, reduction="mean")
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)

    def test_contrastive_loss(self):
        anchors = mx.array([[1, 2], [3, 4]])
        positives = mx.array([[1, 3], [2, 4]])
        negatives = mx.array([[5, 6], [7, 8]])
        margin = 1.0
        p = 2
        # For p = 2 the p-norm is the familiar Euclidean distance
        positive_distance = mx.sqrt(mx.power(anchors - positives, p).sum(axis=1))
        negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
        expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
        expected_loss = mx.mean(expected_loss)
        loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=margin, p=p, reduction="mean")
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)

    def test_tversky_loss(self):
        inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
        targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
        alpha = 0.5
        beta = 0.5
        epsilon = 1e-6
        intersection = mx.sum(inputs * targets, axis=1)
        false_positives = mx.sum(inputs * (1 - targets), axis=1)
        false_negatives = mx.sum((1 - inputs) * targets, axis=1)
        tversky_index = (intersection + epsilon) / (
            intersection + alpha * false_positives + beta * false_negatives + epsilon
        )
        expected_loss = 1 - tversky_index
        expected_loss = mx.mean(expected_loss)
        loss = nn.losses.tversky_loss(inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean")
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)


if __name__ == "__main__":
    unittest.main()