mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
lint fix
This commit is contained in:
parent
15307c5367
commit
e7114f4b91
@ -372,12 +372,13 @@ def log_cosh_loss(
|
|||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def focal_loss(
|
def focal_loss(
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
targets: mx.array,
|
targets: mx.array,
|
||||||
alpha: float = 0.25,
|
alpha: float = 0.25,
|
||||||
gamma: float = 2.0,
|
gamma: float = 2.0,
|
||||||
reduction: str = "none"
|
reduction: str = "none",
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r"""
|
r"""
|
||||||
Computes the Focal Loss between inputs and targets, which is designed to address
|
Computes the Focal Loss between inputs and targets, which is designed to address
|
||||||
@ -412,11 +413,9 @@ def focal_loss(
|
|||||||
|
|
||||||
return _reduce(focal_loss, reduction)
|
return _reduce(focal_loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def dice_loss(
|
def dice_loss(
|
||||||
inputs: mx.array,
|
inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
|
||||||
targets: mx.array,
|
|
||||||
epsilon: float = 1e-6,
|
|
||||||
reduction: str = "none"
|
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r"""
|
r"""
|
||||||
Computes the Dice Loss, which is a measure of overlap between two samples.
|
Computes the Dice Loss, which is a measure of overlap between two samples.
|
||||||
@ -438,16 +437,14 @@ def dice_loss(
|
|||||||
"""
|
"""
|
||||||
intersection = mx.sum(inputs * targets, axis=1)
|
intersection = mx.sum(inputs * targets, axis=1)
|
||||||
cardinality = mx.sum(inputs + targets, axis=1)
|
cardinality = mx.sum(inputs + targets, axis=1)
|
||||||
dice_score = (2. * intersection + epsilon) / (cardinality + epsilon)
|
dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
|
||||||
loss = 1 - dice_score
|
loss = 1 - dice_score
|
||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def iou_loss(
|
def iou_loss(
|
||||||
inputs: mx.array,
|
inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
|
||||||
targets: mx.array,
|
|
||||||
epsilon: float = 1e-6,
|
|
||||||
reduction: str = "none"
|
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r"""
|
r"""
|
||||||
Computes the Intersection over Union (IoU) Loss, which is a measure of the
|
Computes the Intersection over Union (IoU) Loss, which is a measure of the
|
||||||
@ -474,13 +471,14 @@ def iou_loss(
|
|||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def contrastive_loss(
|
def contrastive_loss(
|
||||||
anchors: mx.array,
|
anchors: mx.array,
|
||||||
positives: mx.array,
|
positives: mx.array,
|
||||||
negatives: mx.array,
|
negatives: mx.array,
|
||||||
margin: float = 1.0,
|
margin: float = 1.0,
|
||||||
p: int = 2,
|
p: int = 2,
|
||||||
reduction: str = "none"
|
reduction: str = "none",
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r"""
|
r"""
|
||||||
Computes the Contrastive Loss for a set of anchor, positive, and negative samples.
|
Computes the Contrastive Loss for a set of anchor, positive, and negative samples.
|
||||||
@ -507,13 +505,14 @@ def contrastive_loss(
|
|||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def tversky_loss(
|
def tversky_loss(
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
targets: mx.array,
|
targets: mx.array,
|
||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
beta: float = 0.5,
|
beta: float = 0.5,
|
||||||
epsilon: float = 1e-6,
|
epsilon: float = 1e-6,
|
||||||
reduction: str = "none"
|
reduction: str = "none",
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r"""
|
r"""
|
||||||
Computes the Tversky Loss, a generalization of the Dice Loss, allowing more control over false
|
Computes the Tversky Loss, a generalization of the Dice Loss, allowing more control over false
|
||||||
@ -538,7 +537,9 @@ def tversky_loss(
|
|||||||
intersection = mx.sum(inputs * targets, axis=1)
|
intersection = mx.sum(inputs * targets, axis=1)
|
||||||
false_negatives = mx.sum(inputs * (1 - targets), axis=1)
|
false_negatives = mx.sum(inputs * (1 - targets), axis=1)
|
||||||
false_positives = mx.sum((1 - inputs) * targets, axis=1)
|
false_positives = mx.sum((1 - inputs) * targets, axis=1)
|
||||||
tversky_index = (intersection + epsilon) / (intersection + alpha * false_negatives + beta * false_positives + epsilon)
|
tversky_index = (intersection + epsilon) / (
|
||||||
|
intersection + alpha * false_negatives + beta * false_positives + epsilon
|
||||||
|
)
|
||||||
loss = 1 - tversky_index
|
loss = 1 - tversky_index
|
||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
@ -801,7 +801,9 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
pt = mx.exp(-ce_loss)
|
pt = mx.exp(-ce_loss)
|
||||||
expected_loss = -alpha * ((1 - pt) ** gamma) * ce_loss
|
expected_loss = -alpha * ((1 - pt) ** gamma) * ce_loss
|
||||||
expected_loss = mx.mean(expected_loss)
|
expected_loss = mx.mean(expected_loss)
|
||||||
loss = nn.losses.focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="mean")
|
loss = nn.losses.focal_loss(
|
||||||
|
inputs, targets, alpha=alpha, gamma=gamma, reduction="mean"
|
||||||
|
)
|
||||||
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
||||||
|
|
||||||
def test_dice_loss(self):
|
def test_dice_loss(self):
|
||||||
@ -810,7 +812,7 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
intersection = mx.sum(inputs * targets, axis=1)
|
intersection = mx.sum(inputs * targets, axis=1)
|
||||||
cardinality = mx.sum(inputs + targets, axis=1)
|
cardinality = mx.sum(inputs + targets, axis=1)
|
||||||
dice_score = (2. * intersection + epsilon) / (cardinality + epsilon)
|
dice_score = (2.0 * intersection + epsilon) / (cardinality + epsilon)
|
||||||
expected_loss = 1 - dice_score
|
expected_loss = 1 - dice_score
|
||||||
expected_loss = mx.mean(expected_loss)
|
expected_loss = mx.mean(expected_loss)
|
||||||
loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
|
loss = nn.losses.dice_loss(inputs, targets, epsilon=epsilon, reduction="mean")
|
||||||
@ -838,7 +840,9 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
|
negative_distance = mx.sqrt(mx.power(anchors - negatives, p).sum(axis=1))
|
||||||
expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
|
expected_loss = mx.maximum(positive_distance - negative_distance + margin, 0)
|
||||||
expected_loss = mx.mean(expected_loss)
|
expected_loss = mx.mean(expected_loss)
|
||||||
loss = nn.losses.contrastive_loss(anchors, positives, negatives, margin=margin, p=p, reduction="mean")
|
loss = nn.losses.contrastive_loss(
|
||||||
|
anchors, positives, negatives, margin=margin, p=p, reduction="mean"
|
||||||
|
)
|
||||||
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
||||||
|
|
||||||
def test_tversky_loss(self):
|
def test_tversky_loss(self):
|
||||||
@ -850,10 +854,14 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
intersection = mx.sum(inputs * targets, axis=1)
|
intersection = mx.sum(inputs * targets, axis=1)
|
||||||
false_negatives = mx.sum(inputs * (1 - targets), axis=1)
|
false_negatives = mx.sum(inputs * (1 - targets), axis=1)
|
||||||
false_positives = mx.sum((1 - inputs) * targets, axis=1)
|
false_positives = mx.sum((1 - inputs) * targets, axis=1)
|
||||||
tversky_index = (intersection + epsilon) / (intersection + alpha * false_negatives + beta * false_positives + epsilon)
|
tversky_index = (intersection + epsilon) / (
|
||||||
|
intersection + alpha * false_negatives + beta * false_positives + epsilon
|
||||||
|
)
|
||||||
expected_loss = 1 - tversky_index
|
expected_loss = 1 - tversky_index
|
||||||
expected_loss = mx.mean(expected_loss)
|
expected_loss = mx.mean(expected_loss)
|
||||||
loss = nn.losses.tversky_loss(inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean")
|
loss = nn.losses.tversky_loss(
|
||||||
|
inputs, targets, alpha=alpha, beta=beta, epsilon=epsilon, reduction="mean"
|
||||||
|
)
|
||||||
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user