From 19bc6a391afaa6021718c5fbd74772b7c21debbe Mon Sep 17 00:00:00 2001
From: Jyun1998
Date: Tue, 2 Jan 2024 02:51:40 +0900
Subject: [PATCH] change after all tests

---
 python/mlx/nn/losses.py          | 62 +++---------------------------
 python/mlx/nn/test_loss_torch.py | 65 --------------------------------
 2 files changed, 5 insertions(+), 122 deletions(-)
 delete mode 100644 python/mlx/nn/test_loss_torch.py

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 60f5bd12b..d395a0f04 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -259,8 +259,8 @@ def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction:
     Returns:
         mx.array: The computed Dice loss.
     """
-    intersection = mx.sum(inputs * targets, axis=-1)
-    union = mx.sum(inputs, axis=-1) + mx.sum(targets, axis=-1) - intersection
+    intersection = mx.sum(inputs * targets, axis=1)  # Sum over the feature dimension
+    union = mx.sum(inputs, axis=1) + mx.sum(targets, axis=1)
     dice_score = (2. * intersection + eps) / (union + eps)
     loss = 1 - dice_score
     return _reduce(loss, reduction)
@@ -280,10 +280,10 @@ def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma:
     Returns:
         mx.array: The computed Focal loss.
     """
-    BCE_loss = binary_cross_entropy(inputs, targets, reduction)
+    BCE_loss = binary_cross_entropy(inputs, targets)
     pt = mx.exp(-BCE_loss)
     loss = alpha * (1 - pt) ** gamma * BCE_loss
-    return loss
+    return _reduce(loss, reduction)
 
 
 def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
@@ -326,56 +326,4 @@ def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets
     cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm)
 
     loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin))
-    return _reduce(loss, reduction)
-
-def test_losses():
-    # Hinge Loss Test
-    predictions = mx.array([0.8, -1.5])
-    targets = mx.array([1, -1])
-    print("Hinge Loss:", hinge_loss(predictions, targets))
-    # Expected Result: [0.2, 0] v
-
-    # Huber Loss Test
-    predictions = mx.array([1.5, 0.5])
-    targets = mx.array([1, 0])
-    delta = 1.0
-    print("Huber Loss:", huber_loss(predictions, targets, delta))
-    # Expected Result: [0.125, 0.125] v
-
-    # Dice Loss Test
-    inputs = mx.array([0.7, 0.3])
-    targets = mx.array([1, 0])
-    print("Dice Loss:", dice_loss(inputs, targets))
-    # Expected Result: [0.42857143] ([0.1765, 1.0000])
-
-    # Focal Loss Test
-    inputs = mx.array([0.9, 0.1])
-    targets = mx.array([1, 0])
-    alpha = 0.25
-    gamma = 2.0
-    print("Focal Loss:", focal_loss(inputs, targets, alpha, gamma))
-    # Expected Result: [0.002025, 0.2304]
-
-    # Contrastive Loss Test
-    embeddings1 = mx.array([[1, 2], [3, 4]])
-    embeddings2 = mx.array([[2, 3], [4, 5]])
-    targets = mx.array([1, 0])
-    margin = 1.0
-    print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets, margin))
-    # Expected Result: [1.4142135, 0.0] v
-
-    # Cosine Similarity Loss Test
-    embeddings1 = mx.array([[1, 0], [0, 1]])
-    embeddings2 = mx.array([[0, 1], [1, 0]])
-    targets = mx.array([1, -1])
-    print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
-    # Expected Result: [1, 0]
-
-# Run the tests
-test_losses()
-# Hinge Loss: tensor(0.1000)
-# Huber Loss: tensor([0.1250, 0.1250])
-# Dice Loss: tensor([0.1765, 1.0000])
-# Focal Loss: tensor([0.0003, 0.0003])
-# Contrastive Loss: tensor([0.7071, 0.0000])
-# Cosine Similarity Loss: tensor([1., 0.])
\ No newline at end of file
+    return _reduce(loss, reduction)
\ No newline at end of file
diff --git a/python/mlx/nn/test_loss_torch.py b/python/mlx/nn/test_loss_torch.py
deleted file mode 100644
index b0e7185b1..000000000
--- a/python/mlx/nn/test_loss_torch.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# Hinge Loss (Custom)
-class HingeLoss(nn.Module):
-    def forward(self, predictions, targets):
-        return torch.mean(torch.clamp(1 - predictions * targets, min=0))
-
-# Dice Loss (Custom)
-class DiceLoss(nn.Module):
-    def forward(self, inputs, targets, epsilon=1e-6):
-        intersection = inputs * targets
-        union = inputs + targets
-        dice_score = (2. * intersection + epsilon) / (union + epsilon)
-        return 1 - dice_score
-
-def focal_loss(inputs, targets, alpha=0.25, gamma=2.0):
-    BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
-    pt = torch.exp(-BCE_loss)
-    F_loss = alpha * (1 - pt) ** gamma * BCE_loss
-    return F_loss
-
-def contrastive_loss(embeddings1, embeddings2, targets, margin=1.0):
-    distances = F.pairwise_distance(embeddings1, embeddings2)
-    loss = 0.5 * (targets * distances + (1 - targets) * F.relu(margin - distances))
-    return loss
-
-# Test cases
-def test_losses():
-    hinge_loss = HingeLoss()
-    huber_loss = nn.SmoothL1Loss(reduction='none')
-    dice_loss = DiceLoss()
-    cosine_similarity_loss = nn.CosineEmbeddingLoss(reduction='none')
-
-    predictions = torch.tensor([0.8, -1.5])
-    targets = torch.tensor([1, -1])
-    print("Hinge Loss:", hinge_loss(predictions, targets))
-
-    predictions = torch.tensor([1.5, 0.5])
-    targets = torch.tensor([1, 0])
-    print("Huber Loss:", huber_loss(predictions, targets))
-
-    inputs = torch.tensor([0.7, 0.3])
-    targets = torch.tensor([1, 0])
-    print("Dice Loss:", dice_loss(inputs, targets))
-
-    inputs = torch.tensor([0.9, 0.1], dtype=torch.float32)
-    targets = torch.tensor([1, 0], dtype=torch.float32)
-    print("Focal Loss:", focal_loss(inputs, targets))
-
-    embeddings1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
-    embeddings2 = torch.tensor([[2, 3], [4, 5]], dtype=torch.float)
-    targets = torch.tensor([1, 0], dtype=torch.float)
-    print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets))
-
-    embeddings1 = torch.tensor([[1, 0], [0, 1]], dtype=torch.float)
-    embeddings2 = torch.tensor([[0, 1], [1, 0]], dtype=torch.float)
-    targets = torch.tensor([1, -1], dtype=torch.float)
-    print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
-
-test_losses()
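
Not part of the patch: a minimal usage sketch of the two losses this diff touches, assuming the patched module is importable as mlx.nn.losses and that inputs and targets are batched 2-D arrays (needed now that dice_loss sums over axis=1). The import path and function names come from the diff above; the example values and the choice of reduction are illustrative only.

    import mlx.core as mx
    from mlx.nn.losses import dice_loss, focal_loss  # assumes the patched module is on the path

    # Two examples with three elements each; after this change dice_loss
    # reduces over axis=1, so reduction="none" yields one score per row.
    probs = mx.array([[0.9, 0.8, 0.1],
                      [0.2, 0.7, 0.6]])
    labels = mx.array([[1.0, 1.0, 0.0],
                       [0.0, 1.0, 1.0]])

    print(dice_loss(probs, labels, reduction="none"))  # shape (2,)

    # focal_loss now applies its own reduction after weighting the per-element
    # BCE; how probs is interpreted (logits vs. probabilities) follows
    # binary_cross_entropy's convention.
    print(focal_loss(probs, labels, alpha=0.25, gamma=2.0, reduction="mean"))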