This commit is contained in:
NripeshN 2023-12-31 01:36:48 +05:30
parent 15307c5367
commit e7114f4b91
2 changed files with 38 additions and 29 deletions

View File

@ -372,12 +372,13 @@ def log_cosh_loss(
return _reduce(loss, reduction) return _reduce(loss, reduction)
def focal_loss( def focal_loss(
inputs: mx.array, inputs: mx.array,
targets: mx.array, targets: mx.array,
alpha: float = 0.25, alpha: float = 0.25,
gamma: float = 2.0, gamma: float = 2.0,
reduction: str = "none" reduction: str = "none",
) -> mx.array: ) -> mx.array:
r""" r"""
Computes the Focal Loss between inputs and targets, which is designed to address Computes the Focal Loss between inputs and targets, which is designed to address
@ -400,26 +401,24 @@ def focal_loss(
""" """
if gamma < 0: if gamma < 0:
raise ValueError(f"Focal loss gamma must be non-negative, got {gamma}.") raise ValueError(f"Focal loss gamma must be non-negative, got {gamma}.")
# Calculating the cross-entropy loss # Calculating the cross-entropy loss
ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs ce_loss = mx.logaddexp(0.0, inputs) - targets * inputs
# Calculating the probability # Calculating the probability
pt = mx.exp(-ce_loss) pt = mx.exp(-ce_loss)
# Calculating Focal Loss # Calculating Focal Loss
focal_loss = -alpha * ((1 - pt) ** gamma) * ce_loss focal_loss = -alpha * ((1 - pt) ** gamma) * ce_loss
return _reduce(focal_loss, reduction) return _reduce(focal_loss, reduction)
def dice_loss( def dice_loss(
inputs: mx.array, inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
targets: mx.array,
epsilon: float = 1e-6,
reduction: str = "none"
) -> mx.array: ) -> mx.array:
r""" r"""
Computes the Dice Loss, which is a measure of overlap between two samples. Computes the Dice Loss, which is a measure of overlap between two samples.
This loss is commonly used for binary segmentation tasks. This loss is commonly used for binary segmentation tasks.
.. math:: .. math::
@ -438,19 +437,17 @@ def dice_loss(
""" """
intersection = mx.sum(inputs * targets, axis=1) intersection = mx.sum(inputs * targets, axis=1)
cardinality = mx.sum(inputs + targets, axis=1) cardinality = mx.sum(inputs + targets, axis=1)
dice_score = (2. * intersection + epsilon) / (cardinality + epsilon) dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
loss = 1 - dice_score loss = 1 - dice_score
return _reduce(loss, reduction) return _reduce(loss, reduction)
def iou_loss( def iou_loss(
inputs: mx.array, inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
targets: mx.array,
epsilon: float = 1e-6,
reduction: str = "none"
) -> mx.array: ) -> mx.array:
r""" r"""
Computes the Intersection over Union (IoU) Loss, which is a measure of the Computes the Intersection over Union (IoU) Loss, which is a measure of the
overlap between two sets, typically used in segmentation tasks. overlap between two sets, typically used in segmentation tasks.
.. math:: .. math::
@ -474,13 +471,14 @@ def iou_loss(
return _reduce(loss, reduction) return _reduce(loss, reduction)
def contrastive_loss( def contrastive_loss(
anchors: mx.array, anchors: mx.array,
positives: mx.array, positives: mx.array,
negatives: mx.array, negatives: mx.array,
margin: float = 1.0, margin: float = 1.0,
p: int = 2, p: int = 2,
reduction: str = "none" reduction: str = "none",
) -> mx.array: ) -> mx.array:
r""" r"""
Computes the Contrastive Loss for a set of anchor, positive, and negative samples. Computes the Contrastive Loss for a set of anchor, positive, and negative samples.
@ -504,16 +502,17 @@ def contrastive_loss(
positive_distance = mx.sqrt(mx.power(anchors - positives, p).sum(axis=1)) positive_distance = mx.sqrt(mx.power(anchors - positives, p).sum(axis=1))
negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1)) negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
loss = mx.maximum(positive_distance - negative_distance + margin, 0) loss = mx.maximum(positive_distance - negative_distance + margin, 0)
return _reduce(loss, reduction) return _reduce(loss, reduction)
def tversky_loss( def tversky_loss(
inputs: mx.array, inputs: mx.array,
targets: mx.array, targets: mx.array,
alpha: float = 0.5, alpha: float = 0.5,
beta: float = 0.5, beta: float = 0.5,
epsilon: float = 1e-6, epsilon: float = 1e-6,
reduction: str = "none" reduction: str = "none",
) -> mx.array: ) -> mx.array:
r""" r"""
Computes the Tversky Loss, a generalization of the Dice Loss, allowing more control over false Computes the Tversky Loss, a generalization of the Dice Loss, allowing more control over false
@ -538,7 +537,9 @@ def tversky_loss(
intersection = mx.sum(inputs * targets, axis=1) intersection = mx.sum(inputs * targets, axis=1)
false_negatives = mx.sum(inputs * (1 - targets), axis=1) false_negatives = mx.sum(inputs * (1 - targets), axis=1)
false_positives = mx.sum((1 - inputs) * targets, axis=1) false_positives = mx.sum((1 - inputs) * targets, axis=1)
tversky_index = (intersection + epsilon) / (intersection + alpha * false_negatives + beta * false_positives + epsilon) tversky_index = (intersection + epsilon) / (
intersection + alpha * false_negatives + beta * false_positives + epsilon
)
loss = 1 - tversky_index loss = 1 - tversky_index
return _reduce(loss, reduction) return _reduce(loss, reduction)

View File

@ -791,7 +791,7 @@ class TestNN(mlx_tests.MLXTestCase):
targets = mx.zeros((2, 4)) targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6) self.assertAlmostEqual(loss.item(), 0.433781, places=6)
def test_focal_loss(self): def test_focal_loss(self):
inputs = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]]) inputs = mx.array([[2.0, -1.0, 3.0, 0.1], [-1.0, 2.0, -0.5, 0.2]])
targets = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]]) targets = mx.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]])
@ -801,16 +801,18 @@ class TestNN(mlx_tests.MLXTestCase):
pt = mx.exp(-ce_loss) pt = mx.exp(-ce_loss)
expected_loss = -alpha * ((1 - pt) ** gamma) * ce_loss expected_loss = -alpha * ((1 - pt) ** gamma) * ce_loss
expected_loss = mx.mean(expected_loss) expected_loss = mx.mean(expected_loss)
loss = nn.losses.focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="mean") loss = nn.losses.focal_loss(
inputs, targets, alpha=alpha, gamma=gamma, reduction="mean"
)
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6) self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_dice_loss(self): def test_dice_loss(self):
inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]]) inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]]) targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
epsilon = 1e-6 epsilon = 1e-6
intersection = mx.sum(inputs * targets, axis=1) intersection = mx.sum(inputs * targets, axis=1)
cardinality = mx.sum(inputs + targets, axis=1) cardinality = mx.sum(inputs + targets, axis=1)
dice_score = (2. * intersection + epsilon) / (cardinality + epsilon) dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
expected_loss = 1 - dice_score expected_loss = 1 - dice_score
expected_loss = mx.mean(expected_loss) expected_loss = mx.mean(expected_loss)
loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean") loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
@ -827,7 +829,7 @@ class TestNN(mlx_tests.MLXTestCase):
expected_loss = mx.mean(expected_loss) expected_loss = mx.mean(expected_loss)
loss = nn.losses.iou_loss(inputs, targets, epsilon=epsilon, reduction="mean") loss = nn.losses.iou_loss(inputs, targets, epsilon=epsilon, reduction="mean")
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6) self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_contrastive_loss(self): def test_contrastive_loss(self):
anchors = mx.array([[1, 2], [3, 4]]) anchors = mx.array([[1, 2], [3, 4]])
positives = mx.array([[1, 3], [2, 4]]) positives = mx.array([[1, 3], [2, 4]])
@ -838,9 +840,11 @@ class TestNN(mlx_tests.MLXTestCase):
negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1)) negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0) expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
expected_loss = mx.mean(expected_loss) expected_loss = mx.mean(expected_loss)
loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=margin, p=p, reduction="mean") loss = nn.losses.contrastive_loss(
anchors, positives, negatives, margin=margin, p=p, reduction="mean"
)
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6) self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
def test_tversky_loss(self): def test_tversky_loss(self):
inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]]) inputs = mx.array([[1, 0, 1, 1], [0, 1, 1, 0]])
targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]]) targets = mx.array([[1, 1, 1, 0], [0, 0, 1, 1]])
@ -850,10 +854,14 @@ class TestNN(mlx_tests.MLXTestCase):
intersection = mx.sum(inputs * targets, axis=1) intersection = mx.sum(inputs * targets, axis=1)
false_negatives = mx.sum(inputs * (1 - targets), axis=1) false_negatives = mx.sum(inputs * (1 - targets), axis=1)
false_positives = mx.sum((1 - inputs) * targets, axis=1) false_positives = mx.sum((1 - inputs) * targets, axis=1)
tversky_index = (intersection + epsilon) / (intersection + alpha * false_negatives + beta * false_positives + epsilon) tversky_index = (intersection + epsilon) / (
intersection + alpha * false_negatives + beta * false_positives + epsilon
)
expected_loss = 1 - tversky_index expected_loss = 1 - tversky_index
expected_loss = mx.mean(expected_loss) expected_loss = mx.mean(expected_loss)
loss = nn.losses.tversky_loss(inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean") loss = nn.losses.tversky_loss(
inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean"
)
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6) self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)