diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index c3e1a86fa..b495bc61a 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -372,12 +372,13 @@ def log_cosh_loss(
     return _reduce(loss, reduction)
 
+
 def focal_loss(
     inputs: mx.array,
     targets: mx.array,
     alpha: float = 0.25,
     gamma: float = 2.0,
-    reduction: str = "none"
+    reduction: str = "none",
 ) -> mx.array:
     r"""
     Computes the Focal Loss between inputs and targets, which is designed to address
@@ -400,26 +401,24 @@ def focal_loss(
     """
     if gamma < 0:
         raise ValueError(f"Focal loss gamma must be non-negative, got {gamma}.")
-    
+
     # Calculating the cross-entropy loss
     ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
-    
+
     # Calculating the probability
     pt = mx.exp(-ce_loss)
-    
+
     # Calculating Focal Loss
     focal_loss = -alpha * ((1 - pt) ** gamma) * ce_loss
-    
+
     return _reduce(focal_loss, reduction)
 
+
 def dice_loss(
-    inputs: mx.array,
-    targets: mx.array,
-    epsilon: float = 1e-6,
-    reduction: str = "none"
+    inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
 ) -> mx.array:
     r"""
-    Computes the Dice Loss, which is a measure of overlap between two samples. 
+    Computes the Dice Loss, which is a measure of overlap between two samples.
     This loss is commonly used for binary segmentation tasks.
 
     .. math::
@@ -438,19 +437,17 @@ def dice_loss(
     """
     intersection = mx.sum(inputs * targets, axis=1)
     cardinality = mx.sum(inputs + targets, axis=1)
-    dice_score = (2. * intersection + epsilon) / (cardinality + epsilon)
+    dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
    loss = 1 - dice_score
     return _reduce(loss, reduction)
 
+
 def iou_loss(
-    inputs: mx.array,
-    targets: mx.array,
-    epsilon: float = 1e-6,
-    reduction: str = "none"
+    inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
 ) -> mx.array:
     r"""
-    Computes the Intersection over Union (IoU) Loss, which is a measure of the 
+    Computes the Intersection over Union (IoU) Loss, which is a measure of the
     overlap between two sets, typically used in segmentation tasks.
 
     .. math::
@@ -474,13 +471,14 @@
     return _reduce(loss, reduction)
 
+
 def contrastive_loss(
     anchors: mx.array,
     positives: mx.array,
     negatives: mx.array,
     margin: float = 1.0,
     p: int = 2,
-    reduction: str = "none"
+    reduction: str = "none",
 ) -> mx.array:
     r"""
     Computes the Contrastive Loss for a set of anchor, positive, and negative samples.
@@ -504,16 +502,17 @@ def contrastive_loss(
     positive_distance = mx.sqrt(mx.power(anchors - positives, p).sum(axis=1))
     negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
     loss = mx.maximum(positive_distance - negative_distance + margin, 0)
-    
+
     return _reduce(loss, reduction)
 
+
 def tversky_loss(
     inputs: mx.array,
     targets: mx.array,
     alpha: float = 0.5,
     beta: float = 0.5,
     epsilon: float = 1e-6,
-    reduction: str = "none"
+    reduction: str = "none",
 ) -> mx.array:
     r"""
     Computes the Tversky Loss, a generalization of the Dice Loss, allowing more control over false
@@ -538,7 +537,9 @@
     intersection = mx.sum(inputs * targets, axis=1)
     false_negatives = mx.sum(inputs * (1 - targets), axis=1)
     false_positives = mx.sum((1 - inputs) * targets, axis=1)
-    tversky_index = (intersection + epsilon) / (intersection + alpha * false_negatives + beta * false_positives + epsilon)
+    tversky_index = (intersection + epsilon) / (
+        intersection + alpha * false_negatives + beta * false_positives + epsilon
+    )
     loss = 1 - tversky_index
     return _reduce(loss, reduction)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 2be9a4c65..03c0a2535 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -791,7 +791,7 @@ class TestNN(mlx_tests.MLXTestCase):
         targets = mx.zeros((2, 4))
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
-    
+
     def test_focal_loss(self):
         inputs = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]])
         targets = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
@@ -801,16 +801,18 @@
         pt = mx.exp(-ce_loss)
         expected_loss = -alpha * ((1 - pt) ** gamma) * ce_loss
         expected_loss = mx.mean(expected_loss)
-        loss = nn.losses.focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="mean")
+        loss = nn.losses.focal_loss(
+            inputs, targets, alpha=alpha, gamma=gamma, reduction="mean"
+        )
         self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
-    
+
     def test_dice_loss(self):
         inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
         targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
         epsilon = 1e-6
         intersection = mx.sum(inputs * targets, axis=1)
         cardinality = mx.sum(inputs + targets, axis=1)
-        dice_score = (2. * intersection + epsilon) / (cardinality + epsilon)
+        dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
         expected_loss = 1 - dice_score
         expected_loss = mx.mean(expected_loss)
         loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
@@ -827,7 +829,7 @@
         expected_loss = mx.mean(expected_loss)
         loss = nn.losses.iou_loss(inputs, targets, epsilon=epsilon, reduction="mean")
         self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
-    
+
     def test_contrastive_loss(self):
         anchors = mx.array([[1, 2], [3, 4]])
         positives = mx.array([[1, 3], [2, 4]])
@@ -838,9 +840,11 @@
         negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
         expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
         expected_loss = mx.mean(expected_loss)
-        loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=margin, p=p, reduction="mean")
+        loss = nn.losses.contrastive_loss(
+            anchors, positives, negatives, margin=margin, p=p, reduction="mean"
+        )
         self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
-    
+
     def test_tversky_loss(self):
         inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
         targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
@@ -850,10 +854,14 @@
         intersection = mx.sum(inputs * targets, axis=1)
         false_negatives = mx.sum(inputs * (1 - targets), axis=1)
         false_positives = mx.sum((1 - inputs) * targets, axis=1)
-        tversky_index = (intersection + epsilon) / (intersection + alpha * false_negatives + beta * false_positives + epsilon)
+        tversky_index = (intersection + epsilon) / (
+            intersection + alpha * false_negatives + beta * false_positives + epsilon
+        )
         expected_loss = 1 - tversky_index
         expected_loss = mx.mean(expected_loss)
-        loss = nn.losses.tversky_loss(inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean")
+        loss = nn.losses.tversky_loss(
+            inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean"
+        )
         self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
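
A minimal usage sketch of the new loss functions, mirroring the calls in the tests above. It assumes this patch is applied to mlx; the input values are purely illustrative.

import mlx.core as mx
import mlx.nn as nn

# Segmentation-style losses reduce over axis=1, so inputs/targets are (batch, features).
inputs = mx.array([[0.9, 0.1, 0.8, 0.7], [0.2, 0.8, 0.6, 0.1]])
targets = mx.array([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]])

print(nn.losses.dice_loss(inputs, targets, reduction="mean").item())
print(nn.losses.iou_loss(inputs, targets, reduction="mean").item())
print(nn.losses.tversky_loss(inputs, targets, alpha=0.7, beta=0.3, reduction="mean").item())

# focal_loss computes a sigmoid cross-entropy internally
# (mx.logaddexp(0.0, x) - t * x), so it takes raw logits and 0/1 targets.
logits = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]])
labels = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
print(nn.losses.focal_loss(logits, labels, alpha=0.25, gamma=2.0, reduction="mean").item())

# contrastive_loss compares batches of anchor/positive/negative embeddings
# using a p-norm distance and a margin.
anchors = mx.array([[1.0, 2.0], [3.0, 4.0]])
positives = mx.array([[1.0, 3.0], [2.0, 4.0]])
negatives = mx.array([[2.0, 1.0], [4.0, 3.0]])
print(
    nn.losses.contrastive_loss(
        anchors, positives, negatives, margin=1.0, p=2, reduction="mean"
    ).item()
)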