From 15307c5367e65736322ba597f9cc0a5a9bddd989 Mon Sep 17 00:00:00 2001
From: NripeshN
Date: Sun, 31 Dec 2023 01:36:34 +0530
Subject: [PATCH] feat: Add new loss functions for neural networks

---
 python/mlx/nn/losses.py | 172 +++++++++++++++++++++++++++++++++++++++-
 python/tests/test_nn.py |  64 +++++++++++++++
 2 files changed, 235 insertions(+), 1 deletion(-)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 91316fd04..c3e1a86fa 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -3,7 +3,6 @@
 import math
 
 import mlx.core as mx
-from mlx.nn.layers.base import Module
 
 
 def cross_entropy(
@@ -372,3 +371,174 @@ def log_cosh_loss(
     loss = mx.logaddexp(errors, -errors) - math.log(2)
 
     return _reduce(loss, reduction)
+
+def focal_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    alpha: float = 0.25,
+    gamma: float = 2.0,
+    reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Focal Loss between inputs and targets, which is designed to address
+    class imbalance by focusing more on hard-to-classify examples.
+
+    .. math::
+
+        FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)
+
+    Args:
+        inputs (array): The predicted logits.
+        targets (array): The ground truth target values.
+        alpha (float, optional): The balancing parameter. Default: ``0.25``.
+        gamma (float, optional): The focusing parameter. Default: ``2.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Focal Loss.
+    """
+    if gamma < 0:
+        raise ValueError(f"Focal loss gamma must be non-negative, got {gamma}.")
+
+    # Per-element binary cross-entropy with logits
+    ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
+
+    # Probability assigned to the target class
+    pt = mx.exp(-ce_loss)
+
+    # Down-weight easy examples with the focal modulating factor
+    loss = alpha * ((1 - pt) ** gamma) * ce_loss
+
+    return _reduce(loss, reduction)
+
+def dice_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    epsilon: float = 1e-6,
+    reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Dice Loss, which is a measure of overlap between two samples.
+    This loss is commonly used for binary segmentation tasks.
+
+    .. math::
+
+        \text{Dice Loss} = 1 - \frac{2 \times |X \cap Y|}{|X| + |Y|}
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The ground truth values.
+        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Dice Loss.
+    """
+    intersection = mx.sum(inputs * targets, axis=1)
+    cardinality = mx.sum(inputs + targets, axis=1)
+    dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
+    loss = 1 - dice_score
+
+    return _reduce(loss, reduction)
+
+def iou_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    epsilon: float = 1e-6,
+    reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Intersection over Union (IoU) Loss, which is a measure of the
+    overlap between two sets, typically used in segmentation tasks.
+
+    .. math::
+
+        \text{IoU Loss} = 1 - \frac{|X \cap Y|}{|X \cup Y|}
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The ground truth values.
+        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed IoU Loss.
+    """
+    intersection = mx.sum(inputs * targets, axis=1)
+    union = mx.sum(inputs + targets - inputs * targets, axis=1)
+    iou_score = (intersection + epsilon) / (union + epsilon)
+    loss = 1 - iou_score
+
+    return _reduce(loss, reduction)
+
+def contrastive_loss(
+    anchors: mx.array,
+    positives: mx.array,
+    negatives: mx.array,
+    margin: float = 1.0,
+    p: int = 2,
+    reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Contrastive Loss for a set of anchor, positive, and negative samples,
+    using a triplet-margin formulation.
+
+    .. math::
+
+        L_{\text{contrastive}} = \max\left(\|A - P\|_p - \|A - N\|_p + \text{margin}, 0\right)
+
+    Args:
+        anchors (array): The anchor samples.
+        positives (array): The positive samples.
+        negatives (array): The negative samples.
+        margin (float, optional): Margin for the contrastive loss. Default: ``1.0``.
+        p (int, optional): The norm degree for pairwise distance. Default: ``2``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed contrastive loss.
+    """
+    # p-norm distances between anchor-positive and anchor-negative pairs
+    positive_distance = mx.sum(mx.abs(anchors - positives) ** p, axis=1) ** (1 / p)
+    negative_distance = mx.sum(mx.abs(anchors - negatives) ** p, axis=1) ** (1 / p)
+    loss = mx.maximum(positive_distance - negative_distance + margin, 0)
+
+    return _reduce(loss, reduction)
+
+def tversky_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    alpha: float = 0.5,
+    beta: float = 0.5,
+    epsilon: float = 1e-6,
+    reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Tversky Loss, a generalization of the Dice Loss that allows separate
+    control over false positives and false negatives. It is particularly useful in
+    segmentation tasks with imbalanced datasets.
+
+    .. math::
+
+        \text{Tversky Loss} = 1 - \frac{|X \cap Y|}{|X \cap Y| + \alpha |X \backslash Y| + \beta |Y \backslash X|}
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The ground truth values.
+        alpha (float, optional): Controls the penalty for false positives
+            (:math:`|X \backslash Y|`). Default: ``0.5``.
+        beta (float, optional): Controls the penalty for false negatives
+            (:math:`|Y \backslash X|`). Default: ``0.5``.
+        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Tversky Loss.
+    """
+    intersection = mx.sum(inputs * targets, axis=1)
+    false_positives = mx.sum(inputs * (1 - targets), axis=1)
+    false_negatives = mx.sum((1 - inputs) * targets, axis=1)
+    tversky_index = (intersection + epsilon) / (
+        intersection + alpha * false_positives + beta * false_negatives + epsilon
+    )
+    loss = 1 - tversky_index
+
+    return _reduce(loss, reduction)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 204145c01..2be9a4c65 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -791,6 +791,70 @@ class TestNN(mlx_tests.MLXTestCase):
         targets = mx.zeros((2, 4))
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
+
+    def test_focal_loss(self):
+        inputs = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]])
+        targets = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
+        alpha = 0.25
+        gamma = 2.0
+        ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
+        pt = mx.exp(-ce_loss)
+        expected_loss = alpha * ((1 - pt) ** gamma) * ce_loss
+        expected_loss = mx.mean(expected_loss)
+        loss = nn.losses.focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="mean")
+        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
+
+    def test_dice_loss(self):
+        inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
+        targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
+        epsilon = 1e-6
+        intersection = mx.sum(inputs * targets, axis=1)
+        cardinality = mx.sum(inputs + targets, axis=1)
+        dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
+        expected_loss = 1 - dice_score
+        expected_loss = mx.mean(expected_loss)
+        loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
+        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
+
+    def test_iou_loss(self):
+        inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
+        targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
+        epsilon = 1e-6
+        intersection = mx.sum(inputs * targets, axis=1)
+        union = mx.sum(inputs + targets - inputs * targets, axis=1)
+        iou_score = (intersection + epsilon) / (union + epsilon)
+        expected_loss = 1 - iou_score
+        expected_loss = mx.mean(expected_loss)
+        loss = nn.losses.iou_loss(inputs, targets, epsilon=epsilon, reduction="mean")
+        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
+
+    def test_contrastive_loss(self):
+        anchors = mx.array([[1, 2], [3, 4]])
+        positives = mx.array([[1, 3], [2, 4]])
+        negatives = mx.array([[5, 6], [7, 8]])
+        margin = 1.0
+        p = 2
+        positive_distance = mx.sum(mx.abs(anchors - positives) ** p, axis=1) ** (1 / p)
+        negative_distance = mx.sum(mx.abs(anchors - negatives) ** p, axis=1) ** (1 / p)
+        expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
+        expected_loss = mx.mean(expected_loss)
+        loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=margin, p=p, reduction="mean")
+        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
+
+    def test_tversky_loss(self):
+        inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
+        targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
+        alpha = 0.5
+        beta = 0.5
+        epsilon = 1e-6
+        intersection = mx.sum(inputs * targets, axis=1)
+        false_positives = mx.sum(inputs * (1 - targets), axis=1)
+        false_negatives = mx.sum((1 - inputs) * targets, axis=1)
+        tversky_index = (intersection + epsilon) / (
+            intersection + alpha * false_positives + beta * false_negatives + epsilon
+        )
+        expected_loss = 1 - tversky_index
+        expected_loss = mx.mean(expected_loss)
+        loss = nn.losses.tversky_loss(inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean")
+        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
 
 
 if __name__ == "__main__":
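
For reference, a minimal usage sketch of the new losses (illustrative only, not part of the
patch; assumes an MLX build that includes this change):

    import mlx.core as mx
    import mlx.nn as nn

    # Segmentation-style losses take per-sample rows of predictions and targets.
    preds = mx.array([[0.9, 0.1, 0.8, 0.7], [0.2, 0.8, 0.6, 0.1]])
    masks = mx.array([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]])
    print(nn.losses.dice_loss(preds, masks, reduction="mean"))
    print(nn.losses.iou_loss(preds, masks, reduction="mean"))
    print(nn.losses.tversky_loss(preds, masks, alpha=0.7, beta=0.3, reduction="mean"))

    # focal_loss expects logits; contrastive_loss takes anchor/positive/negative embeddings.
    logits = mx.array([[2.0, -1.0, 0.5], [-0.5, 1.5, 0.3]])
    labels = mx.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
    print(nn.losses.focal_loss(logits, labels, reduction="mean"))

    anchors = mx.array([[0.1, 0.2], [0.3, 0.4]])
    positives = mx.array([[0.1, 0.25], [0.35, 0.4]])
    negatives = mx.array([[0.9, 0.8], [0.7, 0.6]])
    print(nn.losses.contrastive_loss(anchors, positives, negatives, margin=1.0, reduction="mean"))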