feat: Add new loss functions for neural networks

NripeshN 2023-12-31 01:36:34 +05:30
parent a020a2d49d
commit 15307c5367
2 changed files with 235 additions and 1 deletion


@@ -3,7 +3,6 @@
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
def cross_entropy(
@@ -372,3 +371,174 @@ def log_cosh_loss(
loss = mx.logaddexp(errors, -errors) - math.log(2)
return _reduce(loss, reduction)
def focal_loss(
inputs: mx.array,
targets: mx.array,
alpha: float = 0.25,
gamma: float = 2.0,
reduction: str = "none"
) -> mx.array:
r"""
Computes the Focal Loss between inputs and targets, which is designed to address
class imbalance by focusing more on hard-to-classify examples.
.. math::
FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)
Args:
inputs (array): The predicted logits (unnormalized scores).
targets (array): The ground truth target values.
alpha (float, optional): The balancing parameter. Default: 0.25.
gamma (float, optional): The focusing parameter. Default: 2.0.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed Focal Loss.
"""
if gamma < 0:
raise ValueError(f"Focal loss gamma must be non-negative, got {gamma}.")
# Binary cross-entropy computed directly from the logits
ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
# Probability assigned to the true class
pt = mx.exp(-ce_loss)
# Focal term: ce_loss is already -log(pt), so no extra negation is needed
focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss
return _reduce(focal_loss, reduction)
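A minimal usage sketch for ``focal_loss`` (illustrative values, assuming an ``mlx`` build that includes this function; ``inputs`` are raw logits):
import mlx.core as mx
import mlx.nn as nn

# Raw logits and 0/1 targets for a small batch
logits = mx.array([[2.0, -1.0, 3.0, 0.1]])
targets = mx.array([[1.0, 0.0, 0.0, 1.0]])
loss = nn.losses.focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")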
def dice_loss(
inputs: mx.array,
targets: mx.array,
epsilon: float = 1e-6,
reduction: str = "none"
) -> mx.array:
r"""
Computes the Dice Loss, which is a measure of overlap between two samples.
This loss is commonly used for binary segmentation tasks.
.. math::
\text{Dice Loss} = 1 - \frac{2 \times |X \cap Y|}{|X| + |Y|}
Args:
inputs (array): The predicted values.
targets (array): The ground truth values.
epsilon (float, optional): Small constant for numerical stability. Default: 1e-6.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed Dice Loss.
"""
intersection = mx.sum(inputs * targets, axis=1)
cardinality = mx.sum(inputs + targets, axis=1)
dice_score = (2. * intersection + epsilon) / (cardinality + epsilon)
loss = 1 - dice_score
return _reduce(loss, reduction)
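A quick sketch of ``dice_loss`` on per-row binary masks (illustrative values, assuming an ``mlx`` build with this function):
import mlx.core as mx
import mlx.nn as nn

# Binary predictions and ground-truth masks, one row per sample
preds = mx.array([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]])
masks = mx.array([[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
loss = nn.losses.dice_loss(preds, masks, reduction="mean")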
def iou_loss(
inputs: mx.array,
targets: mx.array,
epsilon: float = 1e-6,
reduction: str = "none"
) -> mx.array:
r"""
Computes the Intersection over Union (IoU) Loss, which is a measure of the
overlap between two sets, typically used in segmentation tasks.
.. math::
\text{IoU Loss} = 1 - \frac{|X \cap Y|}{|X \cup Y|}
Args:
inputs (array): The predicted values.
targets (array): The ground truth values.
epsilon (float, optional): Small constant for numerical stability. Default: 1e-6.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed IoU Loss.
"""
intersection = mx.sum(inputs * targets, axis=1)
union = mx.sum(inputs + targets - inputs * targets, axis=1)
iou_score = (intersection + epsilon) / (union + epsilon)
loss = 1 - iou_score
return _reduce(loss, reduction)
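``iou_loss`` is called the same way as ``dice_loss``; a sketch with the same illustrative masks (assuming an ``mlx`` build with this function):
import mlx.core as mx
import mlx.nn as nn

preds = mx.array([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]])
masks = mx.array([[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
loss = nn.losses.iou_loss(preds, masks, reduction="mean")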
def contrastive_loss(
anchors: mx.array,
positives: mx.array,
negatives: mx.array,
margin: float = 1.0,
p: int = 2,
reduction: str = "none"
) -> mx.array:
r"""
Computes the Contrastive Loss for a set of anchor, positive, and negative samples.
.. math::
L_{\text{contrastive}} = \max\left(\|A - P\|_p - \|A - N\|_p + \text{margin}, 0\right)
Args:
anchors (array): The anchor samples.
positives (array): The positive samples.
negatives (array): The negative samples.
margin (float, optional): Margin for the contrastive loss. Default: ``1.0``.
p (int, optional): The norm degree for pairwise distance. Default: ``2``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: Computed contrastive loss.
"""
# p-norm distances between the anchor/positive and anchor/negative pairs
positive_distance = mx.sum(mx.abs(anchors - positives) ** p, axis=1) ** (1 / p)
negative_distance = mx.sum(mx.abs(anchors - negatives) ** p, axis=1) ** (1 / p)
loss = mx.maximum(positive_distance - negative_distance + margin, 0)
return _reduce(loss, reduction)
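A sketch of ``contrastive_loss`` on anchor/positive/negative embeddings (illustrative values; with the default ``p=2`` the pairwise distances are Euclidean):
import mlx.core as mx
import mlx.nn as nn

anchors = mx.array([[1.0, 2.0], [3.0, 4.0]])
positives = mx.array([[1.0, 3.0], [2.0, 4.0]])
negatives = mx.array([[5.0, 6.0], [7.0, 8.0]])
loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=1.0, p=2, reduction="mean")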
def tversky_loss(
inputs: mx.array,
targets: mx.array,
alpha: float = 0.5,
beta: float = 0.5,
epsilon: float = 1e-6,
reduction: str = "none"
) -> mx.array:
r"""
Computes the Tversky Loss, a generalization of the Dice Loss, allowing more control over false
positives and false negatives. It is particularly useful in segmentation tasks with imbalanced datasets.
.. math::
\text{Tversky Loss} = 1 - \frac{|X \cap Y|}{|X \cap Y| + \alpha |X \backslash Y| + \beta |Y \backslash X|}
Args:
inputs (array): The predicted values.
targets (array): The ground truth values.
alpha (float, optional): Controls the penalty for false positives. Default: 0.5.
beta (float, optional): Controls the penalty for false negatives. Default: 0.5.
epsilon (float, optional): Small constant for numerical stability. Default: 1e-6.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed Tversky Loss.
"""
intersection = mx.sum(inputs * targets, axis=1)
false_positives = mx.sum(inputs * (1 - targets), axis=1)
false_negatives = mx.sum((1 - inputs) * targets, axis=1)
tversky_index = (intersection + epsilon) / (intersection + alpha * false_positives + beta * false_negatives + epsilon)
loss = 1 - tversky_index
return _reduce(loss, reduction)
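A sketch of ``tversky_loss``; raising ``alpha`` above ``beta`` penalizes false positives more heavily (illustrative values, assuming an ``mlx`` build with this function):
import mlx.core as mx
import mlx.nn as nn

preds = mx.array([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]])
masks = mx.array([[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
loss = nn.losses.tversky_loss(preds, masks, alpha=0.7, beta=0.3, reduction="mean")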


@@ -791,6 +791,70 @@ class TestNN(mlx_tests.MLXTestCase):
targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
def test_focal_loss(self):
inputs = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]])
targets = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
alpha = 0.25
gamma = 2.0
ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
pt = mx.exp(-ce_loss)
expected_loss = alpha * ((1 - pt) ** gamma) * ce_loss
expected_loss = mx.mean(expected_loss)
loss = nn.losses.focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="mean")
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_dice_loss(self):
inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
epsilon = 1e-6
intersection = mx.sum(inputs * targets, axis=1)
cardinality = mx.sum(inputs + targets, axis=1)
dice_score = (2. * intersection + epsilon) / (cardinality + epsilon)
expected_loss = 1 - dice_score
expected_loss = mx.mean(expected_loss)
loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_iou_loss(self):
inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
epsilon = 1e-6
intersection = mx.sum(inputs * targets, axis=1)
union = mx.sum(inputs + targets - inputs * targets, axis=1)
iou_score = (intersection + epsilon) / (union + epsilon)
expected_loss = 1 - iou_score
expected_loss = mx.mean(expected_loss)
loss = nn.losses.iou_loss(inputs, targets, epsilon=epsilon, reduction="mean")
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_contrastive_loss(self):
anchors = mx.array([[1, 2], [3, 4]])
positives = mx.array([[1, 3], [2, 4]])
negatives = mx.array([[5, 6], [7, 8]])
margin = 1.0
p = 2
positive_distance = mx.sqrt(mx.power(anchors - positives, p).sum(axis=1))
negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
expected_loss = mx.mean(expected_loss)
loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=margin, p=p, reduction="mean")
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_tversky_loss(self):
inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
alpha = 0.5
beta = 0.5
epsilon = 1e-6
intersection = mx.sum(inputs * targets, axis=1)
false_positives = mx.sum(inputs * (1 - targets), axis=1)
false_negatives = mx.sum((1 - inputs) * targets, axis=1)
tversky_index = (intersection + epsilon) / (intersection + alpha * false_positives + beta * false_negatives + epsilon)
expected_loss = 1 - tversky_index
expected_loss = mx.mean(expected_loss)
loss = nn.losses.tversky_loss(inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean")
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
if __name__ == "__main__":