mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
change after all tests
This commit is contained in:
parent
80b4d84d90
commit
19bc6a391a
@ -259,8 +259,8 @@ def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction:
|
|||||||
Returns:
|
Returns:
|
||||||
mx.array: The computed Dice loss.
|
mx.array: The computed Dice loss.
|
||||||
"""
|
"""
|
||||||
intersection = mx.sum(inputs * targets, axis=-1)
|
intersection = mx.sum(inputs * targets, axis=1) # Sum over the feature dimension
|
||||||
union = mx.sum(inputs, axis=-1) + mx.sum(targets, axis=-1) - intersection
|
union = mx.sum(inputs, axis=1) + mx.sum(targets, axis=1)
|
||||||
dice_score = (2. * intersection + eps) / (union + eps)
|
dice_score = (2. * intersection + eps) / (union + eps)
|
||||||
loss = 1 - dice_score
|
loss = 1 - dice_score
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
@ -280,10 +280,10 @@ def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma:
|
|||||||
Returns:
|
Returns:
|
||||||
mx.array: The computed Focal loss.
|
mx.array: The computed Focal loss.
|
||||||
"""
|
"""
|
||||||
BCE_loss = binary_cross_entropy(inputs, targets, reduction)
|
BCE_loss = binary_cross_entropy(inputs, targets)
|
||||||
pt = mx.exp(-BCE_loss)
|
pt = mx.exp(-BCE_loss)
|
||||||
loss = alpha * (1 - pt) ** gamma * BCE_loss
|
loss = alpha * (1 - pt) ** gamma * BCE_loss
|
||||||
return loss
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
|
def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
|
||||||
@ -327,55 +327,3 @@ def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets
|
|||||||
cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm)
|
cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm)
|
||||||
loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin))
|
loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin))
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
def test_losses():
|
|
||||||
# Hinge Loss Test
|
|
||||||
predictions = mx.array([0.8, -1.5])
|
|
||||||
targets = mx.array([1, -1])
|
|
||||||
print("Hinge Loss:", hinge_loss(predictions, targets))
|
|
||||||
# Expected Result: [0.2, 0] v
|
|
||||||
|
|
||||||
# Huber Loss Test
|
|
||||||
predictions = mx.array([1.5, 0.5])
|
|
||||||
targets = mx.array([1, 0])
|
|
||||||
delta = 1.0
|
|
||||||
print("Huber Loss:", huber_loss(predictions, targets, delta))
|
|
||||||
# Expected Result: [0.125, 0.125] v
|
|
||||||
|
|
||||||
# Dice Loss Test
|
|
||||||
inputs = mx.array([0.7, 0.3])
|
|
||||||
targets = mx.array([1, 0])
|
|
||||||
print("Dice Loss:", dice_loss(inputs, targets))
|
|
||||||
# Expected Result: [0.42857143] ([0.1765, 1.0000])
|
|
||||||
|
|
||||||
# Focal Loss Test
|
|
||||||
inputs = mx.array([0.9, 0.1])
|
|
||||||
targets = mx.array([1, 0])
|
|
||||||
alpha = 0.25
|
|
||||||
gamma = 2.0
|
|
||||||
print("Focal Loss:", focal_loss(inputs, targets, alpha, gamma))
|
|
||||||
# Expected Result: [0.002025, 0.2304]
|
|
||||||
|
|
||||||
# Contrastive Loss Test
|
|
||||||
embeddings1 = mx.array([[1, 2], [3, 4]])
|
|
||||||
embeddings2 = mx.array([[2, 3], [4, 5]])
|
|
||||||
targets = mx.array([1, 0])
|
|
||||||
margin = 1.0
|
|
||||||
print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets, margin))
|
|
||||||
# Expected Result: [1.4142135, 0.0] v
|
|
||||||
|
|
||||||
# Cosine Similarity Loss Test
|
|
||||||
embeddings1 = mx.array([[1, 0], [0, 1]])
|
|
||||||
embeddings2 = mx.array([[0, 1], [1, 0]])
|
|
||||||
targets = mx.array([1, -1])
|
|
||||||
print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
|
|
||||||
# Expected Result: [1, 0]
|
|
||||||
|
|
||||||
# Run the tests
|
|
||||||
test_losses()
|
|
||||||
# Hinge Loss: tensor(0.1000)
|
|
||||||
# Huber Loss: tensor([0.1250, 0.1250])
|
|
||||||
# Dice Loss: tensor([0.1765, 1.0000])
|
|
||||||
# Focal Loss: tensor([0.0003, 0.0003])
|
|
||||||
# Contrastive Loss: tensor([0.7071, 0.0000])
|
|
||||||
# Cosine Similarity Loss: tensor([1., 0.])
|
|
@ -1,65 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# Hinge Loss (Custom)
|
|
||||||
class HingeLoss(nn.Module):
|
|
||||||
def forward(self, predictions, targets):
|
|
||||||
return torch.mean(torch.clamp(1 - predictions * targets, min=0))
|
|
||||||
|
|
||||||
# Dice Loss (Custom)
|
|
||||||
class DiceLoss(nn.Module):
|
|
||||||
def forward(self, inputs, targets, epsilon=1e-6):
|
|
||||||
intersection = inputs * targets
|
|
||||||
union = inputs + targets
|
|
||||||
dice_score = (2. * intersection + epsilon) / (union + epsilon)
|
|
||||||
return 1 - dice_score
|
|
||||||
|
|
||||||
def focal_loss(inputs, targets, alpha=0.25, gamma=2.0):
|
|
||||||
BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
|
|
||||||
pt = torch.exp(-BCE_loss)
|
|
||||||
F_loss = alpha * (1 - pt) ** gamma * BCE_loss
|
|
||||||
return F_loss
|
|
||||||
|
|
||||||
def contrastive_loss(embeddings1, embeddings2, targets, margin=1.0):
|
|
||||||
distances = F.pairwise_distance(embeddings1, embeddings2)
|
|
||||||
loss = 0.5 * (targets * distances + (1 - targets) * F.relu(margin - distances))
|
|
||||||
return loss
|
|
||||||
|
|
||||||
# Test cases
|
|
||||||
def test_losses():
|
|
||||||
hinge_loss = HingeLoss()
|
|
||||||
huber_loss = nn.SmoothL1Loss(reduction='none')
|
|
||||||
dice_loss = DiceLoss()
|
|
||||||
cosine_similarity_loss = nn.CosineEmbeddingLoss(reduction='none')
|
|
||||||
|
|
||||||
predictions = torch.tensor([0.8, -1.5])
|
|
||||||
targets = torch.tensor([1, -1])
|
|
||||||
print("Hinge Loss:", hinge_loss(predictions, targets))
|
|
||||||
|
|
||||||
predictions = torch.tensor([1.5, 0.5])
|
|
||||||
targets = torch.tensor([1, 0])
|
|
||||||
print("Huber Loss:", huber_loss(predictions, targets))
|
|
||||||
|
|
||||||
inputs = torch.tensor([0.7, 0.3])
|
|
||||||
targets = torch.tensor([1, 0])
|
|
||||||
print("Dice Loss:", dice_loss(inputs, targets))
|
|
||||||
|
|
||||||
inputs = torch.tensor([0.9, 0.1], dtype=torch.float32)
|
|
||||||
targets = torch.tensor([1, 0], dtype=torch.float32)
|
|
||||||
print("Focal Loss:", focal_loss(inputs, targets))
|
|
||||||
|
|
||||||
embeddings1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
|
|
||||||
embeddings2 = torch.tensor([[2, 3], [4, 5]], dtype=torch.float)
|
|
||||||
targets = torch.tensor([1, 0], dtype=torch.float)
|
|
||||||
print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets))
|
|
||||||
|
|
||||||
embeddings1 = torch.tensor([[1, 0], [0, 1]], dtype=torch.float)
|
|
||||||
embeddings2 = torch.tensor([[0, 1], [1, 0]], dtype=torch.float)
|
|
||||||
targets = torch.tensor([1, -1], dtype=torch.float)
|
|
||||||
print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
|
|
||||||
|
|
||||||
test_losses()
|
|
Loading…
Reference in New Issue
Block a user