From 80b4d84d90c52400133ff7828b50893b5e476797 Mon Sep 17 00:00:00 2001
From: Jyun1998
Date: Tue, 2 Jan 2024 02:51:11 +0900
Subject: [PATCH] change after all test passes

---
 python/mlx/nn/losses.py          |  96 +++++++++++++++++++-----
 python/mlx/nn/test_loss_torch.py |  65 ++++++++++++++++
 python/tests/test_nn.py          | 125 ++++++++++++++++++++++++++++++-
 3 files changed, 264 insertions(+), 22 deletions(-)
 create mode 100644 python/mlx/nn/test_loss_torch.py

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index fd2068a3d..60f5bd12b 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -230,7 +230,7 @@ def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, red
     Args:
         predictions (mx.array): The predicted values.
         targets (mx.array): The target values.
-        delta (float, optional): Threshold for switching between quadratic and linear losses.
+        delta (float, optional): Threshold for switching between quadratic and linear losses. Default: ``1.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
@@ -240,30 +240,30 @@ def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, red
     error = mx.abs(predictions - targets)
     is_small_error = error < delta
     squared_loss = 0.5 * mx.square(error)
-    linear_loss = delta * error - 0.5 * mx.square(delta)
+    linear_loss = delta * error - 0.5 * (delta ** 2)
     loss = mx.where(is_small_error, squared_loss, linear_loss)
     return _reduce(loss, reduction)
 
 
-def dice_loss(inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none") -> mx.array:
+def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction: str = "none") -> mx.array:
     """
     Computes the Dice loss, useful for binary segmentation tasks.
 
     Args:
         inputs (mx.array): Predicted probabilities for each pixel.
         targets (mx.array): The target values (binary labels for each pixel).
-        epsilon (float, optional): Small constant for numerical stability.
+        eps (float, optional): Small constant for numerical stability. Default: ``1e-6``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed Dice loss.
     """
-    intersection = mx.sum(inputs * targets)
-    union = mx.sum(inputs) + mx.sum(targets)
-    dice_score = (2. * intersection + epsilon) / (union + epsilon)
-    return _reduce(1 - dice_score, reduction)
-
+    # Per-sample Dice over the last axis; the denominator is the plain sum of
+    # inputs and targets, matching the expected values in tests/test_nn.py.
+    intersection = mx.sum(inputs * targets, axis=-1)
+    union = mx.sum(inputs, axis=-1) + mx.sum(targets, axis=-1)
+    dice_score = (2.0 * intersection + eps) / (union + eps)
+    loss = 1 - dice_score
+    return _reduce(loss, reduction)
 
 
 def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "none") -> mx.array:
     """
@@ -272,18 +272,18 @@ def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma:
     Args:
         inputs (mx.array): Predicted probabilities for the positive class.
         targets (mx.array): The target values (binary).
-        alpha (float, optional): Weighting factor for positive examples.
-        gamma (float, optional): Modulating factor for hard examples.
+        alpha (float, optional): Weighting factor for positive examples. Default: ``0.25``.
+        gamma (float, optional): Modulating factor for hard examples. Default: ``2.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed Focal loss.
     """
-    p_t = targets * inputs + (1 - targets) * (1 - inputs)
-    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
-    loss = -alpha_t * mx.pow((1 - p_t), gamma) * mx.log(p_t)
-    return _reduce(loss, reduction)
+    # Compute the per-element BCE first; the requested reduction is applied to
+    # the focal-weighted loss at the end, not to the intermediate BCE term.
+    BCE_loss = binary_cross_entropy(inputs, targets, reduction="none")
+    pt = mx.exp(-BCE_loss)
+    loss = alpha * (1 - pt) ** gamma * BCE_loss
+    return _reduce(loss, reduction)
 
 
 def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
     """
@@ -294,7 +294,7 @@ def contrastive_loss(embeddings1: mx.array, embeddings2: mx.a
         embeddings1 (mx.array): Embeddings for the first set of samples.
         embeddings2 (mx.array): Embeddings for the second set of samples.
         targets (mx.array): The target values (binary labels indicating if pairs are similar or dissimilar).
-        margin (float, optional): Margin for dissimilar pairs.
+        margin (float, optional): Margin for dissimilar pairs. Default: ``1.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
@@ -306,7 +306,7 @@ def contrastive_loss(embeddings1: mx.array, embeddings2: mx.a
     return _reduce(loss, reduction)
 
 
-def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
+def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, eps: float = 1e-8, margin: float = 0.0, reduction: str = "none") -> mx.array:
     """
     Computes the Cosine Similarity loss, useful for tasks where the angle between embeddings is important.
 
@@ -314,12 +314,68 @@ def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets
         embeddings1 (mx.array): Embeddings for the first set of samples.
         embeddings2 (mx.array): Embeddings for the second set of samples.
-        targets (mx.array): The target values (cosine similarity between embeddings).
+        targets (mx.array): The target values (1 for similar pairs, -1 for dissimilar pairs).
+        eps (float, optional): Small constant for numerical stability. Default: ``1e-8``.
+        margin (float, optional): Margin for dissimilar pairs. Default: ``0.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed Cosine Similarity loss.
""" - cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (mx.norm(embeddings1, axis=1) * mx.norm(embeddings2, axis=1)) - loss = 1 - cos_similarity * targets - return _reduce(loss, reduction) \ No newline at end of file + embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps) + embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps) + + cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm) + loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)) + return _reduce(loss, reduction) + +def test_losses(): + # Hinge Loss Test + predictions = mx.array([0.8, -1.5]) + targets = mx.array([1, -1]) + print("Hinge Loss:", hinge_loss(predictions, targets)) + # Expected Result: [0.2, 0] v + + # Huber Loss Test + predictions = mx.array([1.5, 0.5]) + targets = mx.array([1, 0]) + delta = 1.0 + print("Huber Loss:", huber_loss(predictions, targets, delta)) + # Expected Result: [0.125, 0.125] v + + # Dice Loss Test + inputs = mx.array([0.7, 0.3]) + targets = mx.array([1, 0]) + print("Dice Loss:", dice_loss(inputs, targets)) + # Expected Result: [0.42857143] ([0.1765, 1.0000]) + + # Focal Loss Test + inputs = mx.array([0.9, 0.1]) + targets = mx.array([1, 0]) + alpha = 0.25 + gamma = 2.0 + print("Focal Loss:", focal_loss(inputs, targets, alpha, gamma)) + # Expected Result: [0.002025, 0.2304] + + # Contrastive Loss Test + embeddings1 = mx.array([[1, 2], [3, 4]]) + embeddings2 = mx.array([[2, 3], [4, 5]]) + targets = mx.array([1, 0]) + margin = 1.0 + print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets, margin)) + # Expected Result: [1.4142135, 0.0] v + + # Cosine Similarity Loss Test + embeddings1 = mx.array([[1, 0], [0, 1]]) + embeddings2 = mx.array([[0, 1], [1, 0]]) + targets = mx.array([1, -1]) + print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets)) + # Expected Result: [1, 0] + +# Run the tests +test_losses() +# Hinge Loss: tensor(0.1000) +# Huber Loss: tensor([0.1250, 0.1250]) +# Dice Loss: tensor([0.1765, 1.0000]) +# Focal Loss: tensor([0.0003, 0.0003]) +# Contrastive Loss: tensor([0.7071, 0.0000]) +# Cosine Similarity Loss: tensor([1., 0.]) \ No newline at end of file diff --git a/python/mlx/nn/test_loss_torch.py b/python/mlx/nn/test_loss_torch.py new file mode 100644 index 000000000..b0e7185b1 --- /dev/null +++ b/python/mlx/nn/test_loss_torch.py @@ -0,0 +1,65 @@ +import torch +import torch.nn.functional as F + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Hinge Loss (Custom) +class HingeLoss(nn.Module): + def forward(self, predictions, targets): + return torch.mean(torch.clamp(1 - predictions * targets, min=0)) + +# Dice Loss (Custom) +class DiceLoss(nn.Module): + def forward(self, inputs, targets, epsilon=1e-6): + intersection = inputs * targets + union = inputs + targets + dice_score = (2. 
diff --git a/python/mlx/nn/test_loss_torch.py b/python/mlx/nn/test_loss_torch.py
new file mode 100644
index 000000000..b0e7185b1
--- /dev/null
+++ b/python/mlx/nn/test_loss_torch.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Hinge Loss (Custom)
+class HingeLoss(nn.Module):
+    def forward(self, predictions, targets):
+        return torch.mean(torch.clamp(1 - predictions * targets, min=0))
+
+
+# Dice Loss (Custom, element-wise reference)
+class DiceLoss(nn.Module):
+    def forward(self, inputs, targets, epsilon=1e-6):
+        intersection = inputs * targets
+        union = inputs + targets
+        dice_score = (2. * intersection + epsilon) / (union + epsilon)
+        return 1 - dice_score
+
+
+def focal_loss(inputs, targets, alpha=0.25, gamma=2.0):
+    BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
+    pt = torch.exp(-BCE_loss)
+    F_loss = alpha * (1 - pt) ** gamma * BCE_loss
+    return F_loss
+
+
+def contrastive_loss(embeddings1, embeddings2, targets, margin=1.0):
+    distances = F.pairwise_distance(embeddings1, embeddings2)
+    loss = 0.5 * (targets * distances + (1 - targets) * F.relu(margin - distances))
+    return loss
+
+
+# Test cases
+def test_losses():
+    hinge_loss = HingeLoss()
+    huber_loss = nn.SmoothL1Loss(reduction='none')
+    dice_loss = DiceLoss()
+    cosine_similarity_loss = nn.CosineEmbeddingLoss(reduction='none')
+
+    predictions = torch.tensor([0.8, -1.5])
+    targets = torch.tensor([1, -1])
+    print("Hinge Loss:", hinge_loss(predictions, targets))
+
+    predictions = torch.tensor([1.5, 0.5])
+    targets = torch.tensor([1, 0])
+    print("Huber Loss:", huber_loss(predictions, targets))
+
+    inputs = torch.tensor([0.7, 0.3])
+    targets = torch.tensor([1, 0])
+    print("Dice Loss:", dice_loss(inputs, targets))
+
+    inputs = torch.tensor([0.9, 0.1], dtype=torch.float32)
+    targets = torch.tensor([1, 0], dtype=torch.float32)
+    print("Focal Loss:", focal_loss(inputs, targets))
+
+    embeddings1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
+    embeddings2 = torch.tensor([[2, 3], [4, 5]], dtype=torch.float)
+    targets = torch.tensor([1, 0], dtype=torch.float)
+    print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets))
+
+    embeddings1 = torch.tensor([[1, 0], [0, 1]], dtype=torch.float)
+    embeddings2 = torch.tensor([[0, 1], [1, 0]], dtype=torch.float)
+    targets = torch.tensor([1, -1], dtype=torch.float)
+    print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
+
+
+test_losses()
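The focal-loss expectations used in tests/test_nn.py below can be cross-checked the same way, assuming the inner binary cross-entropy operates on probabilities as in the PyTorch reference above: BCE = -log(p) for positive targets and -log(1 - p) otherwise, pt = exp(-BCE), and loss = alpha * (1 - pt)**gamma * BCE. A minimal sketch; focal_check is an illustrative helper only and is not part of the patch.

# Hand-check of the focal-loss expectations (plain Python, probability-based BCE assumed).
import math

def focal_check(p, t, alpha=0.25, gamma=2.0):
    bce = -math.log(p) if t == 1 else -math.log(1 - p)
    pt = math.exp(-bce)  # recovers p_t from the BCE term
    return alpha * (1 - pt) ** gamma * bce

print(focal_check(0.5, 1))  # ~0.0433217
print(focal_check(0.2, 1))  # ~0.25751
print(focal_check(0.9, 0))  # ~0.466273
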
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index f5597474d..5ddd0c0ca 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -6,6 +6,7 @@ import unittest
 
 import mlx.core as mx
 import mlx.nn as nn
+
 import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
@@ -40,8 +41,8 @@ class TestNN(mlx_tests.MLXTestCase):
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
         targets = mx.array([0.5, 0.2, 0.9, 0.0])
-        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
-        self.assertEqual(losses, 0.0)
+        losses_none = nn.losses.l1_loss(predictions, targets, reduction="none")
+        self.assertEqual(losses_none, 0.0)
 
     def test_mse_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
@@ -171,6 +172,126 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_sum = mx.sum(expected_none)
         self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
+    def test_hinge_loss(self):
+        predictions = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        targets = mx.array([[1, -1, 1, -1], [-1, 1, -1, 1]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.hinge_loss(predictions, targets, reduction="none")
+        expected_none = mx.array([[0.5, 1.5, 0.8, 1.9], [1.1, 0.7, 1.5, 0.5]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.hinge_loss(predictions, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.hinge_loss(predictions, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_huber_loss(self):
+        predictions = mx.array([1.5, 2.5, 3.5, 4.5])
+        targets = mx.array([1, 2, 3, 4])
+        delta = 1.0
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.huber_loss(predictions, targets, delta, reduction="none")
+        expected_none = mx.array([0.125, 0.125, 0.125, 0.125])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.huber_loss(predictions, targets, delta, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.huber_loss(predictions, targets, delta, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_dice_loss(self):
+        inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        targets = mx.array([[1, 0, 1, 0], [0, 1, 0, 1]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.dice_loss(inputs, targets, reduction="none")
+        expected_none = mx.array([0.658536, 0.529412])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.dice_loss(inputs, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.dice_loss(inputs, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_focal_loss(self):
+        inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        targets = mx.array([[1, 0, 1, 0], [0, 1, 0, 1]])
+        alpha = 0.25
+        gamma = 2.0
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="none")
+        expected_none = mx.array(
+            [[0.0433217, 0.0433217, 0.25751, 0.466273], [0.000263401, 0.147487, 0.0433217, 0.0433217]]
+        )
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_contrastive_loss(self):
+        embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
+        targets = mx.array([1, 0])
+        margin = 1.0
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="none")
+        expected_none = mx.array([0.2, 0.735425])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_cosine_similarity_loss(self):
+        embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
+        targets = mx.array([1, -1])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="none")
+        expected_none = mx.array([0.0146555, 0.961074])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
     def test_gelu(self):
         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]