mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-22 21:16:47 +08:00

commit 80b4d84d90 (parent bc2fa2a952)

    change after all test passes
@@ -230,7 +230,7 @@ def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, red
     Args:
         predictions (mx.array): The predicted values.
         targets (mx.array): The target values.
-        delta (float, optional): Threshold for switching between quadratic and linear losses.
+        delta (float, optional): Threshold for switching between quadratic and linear losses. Default: ``1.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

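Sketch (not part of this commit): the documented defaults can be exercised directly with the same values used by the inline checks further down, assuming this branch's `mlx.nn.losses` module is importable:

    import mlx.core as mx
    import mlx.nn as nn

    predictions = mx.array([1.5, 0.5])
    targets = mx.array([1.0, 0.0])
    # delta defaults to 1.0 and reduction to 'none', per the docstring above
    print(nn.losses.huber_loss(predictions, targets))  # expected roughly [0.125, 0.125]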
@@ -240,30 +240,30 @@ def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, red
     error = mx.abs(predictions - targets)
     is_small_error = error < delta
     squared_loss = 0.5 * mx.square(error)
-    linear_loss = delta * error - 0.5 * mx.square(delta)
+    linear_loss = delta * error - 0.5 * (delta ** 2)
     loss = mx.where(is_small_error, squared_loss, linear_loss)
     return _reduce(loss, reduction)


-def dice_loss(inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none") -> mx.array:
+def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction: str = "none") -> mx.array:
     """
     Computes the Dice loss, useful for binary segmentation tasks.

     Args:
         inputs (mx.array): Predicted probabilities for each pixel.
         targets (mx.array): The target values (binary labels for each pixel).
-        epsilon (float, optional): Small constant for numerical stability.
+        eps (float, optional): Small constant for numerical stability. Default: ``1e-6``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

     Returns:
         mx.array: The computed Dice loss.
     """
-    intersection = mx.sum(inputs * targets)
-    union = mx.sum(inputs) + mx.sum(targets)
-    dice_score = (2. * intersection + epsilon) / (union + epsilon)
-    return _reduce(1 - dice_score, reduction)
-
+    intersection = mx.sum(inputs * targets, axis=-1)
+    union = mx.sum(inputs, axis=-1) + mx.sum(targets, axis=-1) - intersection
+    dice_score = (2. * intersection + eps) / (union + eps)
+    loss = 1 - dice_score
+    return _reduce(loss, reduction)

 def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "none") -> mx.array:
     """
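For reference, the piecewise form that the `huber_loss` body above implements, with e = |prediction - target|; the rewritten linear branch `delta * error - 0.5 * (delta ** 2)` keeps the two branches equal at e = delta:

    l(e) =
    \begin{cases}
      \tfrac{1}{2}\, e^{2}, & e < \delta \\
      \delta\, e - \tfrac{1}{2}\, \delta^{2}, & e \ge \delta
    \end{cases}

Note also that the reworked `dice_loss` denominator (sum of inputs plus sum of targets minus their intersection) is a Jaccard-style union rather than the classic Dice denominator |A| + |B|.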
@@ -272,18 +272,18 @@ def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma:
     Args:
         inputs (mx.array): Predicted probabilities for the positive class.
         targets (mx.array): The target values (binary).
-        alpha (float, optional): Weighting factor for positive examples.
-        gamma (float, optional): Modulating factor for hard examples.
+        alpha (float, optional): Weighting factor for positive examples. Default: ``0.25``.
+        gamma (float, optional): Modulating factor for hard examples. Default: ``2.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

     Returns:
         mx.array: The computed Focal loss.
     """
-    p_t = targets * inputs + (1 - targets) * (1 - inputs)
-    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
-    loss = -alpha_t * mx.pow((1 - p_t), gamma) * mx.log(p_t)
-    return _reduce(loss, reduction)
+    BCE_loss = binary_cross_entropy(inputs, targets, reduction)
+    pt = mx.exp(-BCE_loss)
+    loss = alpha * (1 - pt) ** gamma * BCE_loss
+    return loss


 def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
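The rewritten `focal_loss` body recovers p_t from the elementwise binary cross-entropy: since BCE = -log(p_t) for binary targets, `mx.exp(-BCE_loss)` gives p_t back, so the new expression matches the usual focal-loss form. A sketch of the algebra (it holds elementwise, i.e. when the inner call effectively uses ``reduction='none'``; note that `alpha` is now applied uniformly rather than as the class-dependent alpha_t of the removed lines):

    \mathrm{FL}(p_t) = \alpha \,(1 - p_t)^{\gamma}\,\mathrm{BCE}
                     = -\alpha \,(1 - p_t)^{\gamma}\,\log(p_t),
    \qquad p_t = e^{-\mathrm{BCE}}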
@@ -294,7 +294,7 @@ def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.a
         embeddings1 (mx.array): Embeddings for the first set of samples.
         embeddings2 (mx.array): Embeddings for the second set of samples.
         targets (mx.array): The target values (binary labels indicating if pairs are similar or dissimilar).
-        margin (float, optional): Margin for dissimilar pairs.
+        margin (float, optional): Margin for dissimilar pairs. Default: ``1.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

@@ -306,7 +306,7 @@ def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.a
     return _reduce(loss, reduction)


-def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
+def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, eps: float=1e-8, margin: float=0.0, reduction: str = "none") -> mx.array:
     """
     Computes the Cosine Similarity loss, useful for tasks where the angle between embeddings is important.

@@ -314,12 +314,68 @@ def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets
         embeddings1 (mx.array): Embeddings for the first set of samples.
         embeddings2 (mx.array): Embeddings for the second set of samples.
         targets (mx.array): The target values (cosine similarity between embeddings).
+        margin (float, optional): Margin for dissimilar pairs. Default: ``0.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

     Returns:
         mx.array: The computed Cosine Similarity loss.
     """
-    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (mx.norm(embeddings1, axis=1) * mx.norm(embeddings2, axis=1))
-    loss = 1 - cos_similarity * targets
+    embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
+    embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
+
+    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm)
+    loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin))
     return _reduce(loss, reduction)
+
+
+def test_losses():
+    # Hinge Loss Test
+    predictions = mx.array([0.8, -1.5])
+    targets = mx.array([1, -1])
+    print("Hinge Loss:", hinge_loss(predictions, targets))
+    # Expected Result: [0.2, 0]
+
+    # Huber Loss Test
+    predictions = mx.array([1.5, 0.5])
+    targets = mx.array([1, 0])
+    delta = 1.0
+    print("Huber Loss:", huber_loss(predictions, targets, delta))
+    # Expected Result: [0.125, 0.125]
+
+    # Dice Loss Test
+    inputs = mx.array([0.7, 0.3])
+    targets = mx.array([1, 0])
+    print("Dice Loss:", dice_loss(inputs, targets))
+    # Expected Result: [0.42857143] ([0.1765, 1.0000])
+
+    # Focal Loss Test
+    inputs = mx.array([0.9, 0.1])
+    targets = mx.array([1, 0])
+    alpha = 0.25
+    gamma = 2.0
+    print("Focal Loss:", focal_loss(inputs, targets, alpha, gamma))
+    # Expected Result: [0.002025, 0.2304]
+
+    # Contrastive Loss Test
+    embeddings1 = mx.array([[1, 2], [3, 4]])
+    embeddings2 = mx.array([[2, 3], [4, 5]])
+    targets = mx.array([1, 0])
+    margin = 1.0
+    print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets, margin))
+    # Expected Result: [1.4142135, 0.0]
+
+    # Cosine Similarity Loss Test
+    embeddings1 = mx.array([[1, 0], [0, 1]])
+    embeddings2 = mx.array([[0, 1], [1, 0]])
+    targets = mx.array([1, -1])
+    print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
+    # Expected Result: [1, 0]
+
+
+# Run the tests
+test_losses()
+
+# Hinge Loss: tensor(0.1000)
+# Huber Loss: tensor([0.1250, 0.1250])
+# Dice Loss: tensor([0.1765, 1.0000])
+# Focal Loss: tensor([0.0003, 0.0003])
+# Contrastive Loss: tensor([0.7071, 0.0000])
+# Cosine Similarity Loss: tensor([1., 0.])
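The new `mx.where` branch gives the cosine loss a cosine-embedding-style form: similar pairs (target 1) pay 1 - cos, dissimilar pairs pay the amount by which cos exceeds the margin, which is also what the `nn.CosineEmbeddingLoss` reference in the PyTorch script below computes. A sketch of the formula implemented above:

    L(x_1, x_2, y) =
    \begin{cases}
      1 - \cos(x_1, x_2), & y = 1 \\
      \max\bigl(0,\ \cos(x_1, x_2) - \text{margin}\bigr), & \text{otherwise}
    \end{cases}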
python/mlx/nn/test_loss_torch.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import torch
+import torch.nn.functional as F
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Hinge Loss (Custom)
+class HingeLoss(nn.Module):
+    def forward(self, predictions, targets):
+        return torch.mean(torch.clamp(1 - predictions * targets, min=0))
+
+
+# Dice Loss (Custom)
+class DiceLoss(nn.Module):
+    def forward(self, inputs, targets, epsilon=1e-6):
+        intersection = inputs * targets
+        union = inputs + targets
+        dice_score = (2. * intersection + epsilon) / (union + epsilon)
+        return 1 - dice_score
+
+
+def focal_loss(inputs, targets, alpha=0.25, gamma=2.0):
+    BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
+    pt = torch.exp(-BCE_loss)
+    F_loss = alpha * (1 - pt) ** gamma * BCE_loss
+    return F_loss
+
+
+def contrastive_loss(embeddings1, embeddings2, targets, margin=1.0):
+    distances = F.pairwise_distance(embeddings1, embeddings2)
+    loss = 0.5 * (targets * distances + (1 - targets) * F.relu(margin - distances))
+    return loss
+
+
+# Test cases
+def test_losses():
+    hinge_loss = HingeLoss()
+    huber_loss = nn.SmoothL1Loss(reduction='none')
+    dice_loss = DiceLoss()
+    cosine_similarity_loss = nn.CosineEmbeddingLoss(reduction='none')
+
+    predictions = torch.tensor([0.8, -1.5])
+    targets = torch.tensor([1, -1])
+    print("Hinge Loss:", hinge_loss(predictions, targets))
+
+    predictions = torch.tensor([1.5, 0.5])
+    targets = torch.tensor([1, 0])
+    print("Huber Loss:", huber_loss(predictions, targets))
+
+    inputs = torch.tensor([0.7, 0.3])
+    targets = torch.tensor([1, 0])
+    print("Dice Loss:", dice_loss(inputs, targets))
+
+    inputs = torch.tensor([0.9, 0.1], dtype=torch.float32)
+    targets = torch.tensor([1, 0], dtype=torch.float32)
+    print("Focal Loss:", focal_loss(inputs, targets))
+
+    embeddings1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
+    embeddings2 = torch.tensor([[2, 3], [4, 5]], dtype=torch.float)
+    targets = torch.tensor([1, 0], dtype=torch.float)
+    print("Contrastive Loss:", contrastive_loss(embeddings1, embeddings2, targets))
+
+    embeddings1 = torch.tensor([[1, 0], [0, 1]], dtype=torch.float)
+    embeddings2 = torch.tensor([[0, 1], [1, 0]], dtype=torch.float)
+    targets = torch.tensor([1, -1], dtype=torch.float)
+    print("Cosine Similarity Loss:", cosine_similarity_loss(embeddings1, embeddings2, targets))
+
+
+test_losses()
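A hand-check of the elementwise Dice values the reference script prints (``tensor([0.1765, 1.0000])``); a minimal sketch in plain NumPy, not part of the commit:

    import numpy as np

    inputs = np.array([0.7, 0.3])
    targets = np.array([1.0, 0.0])
    eps = 1e-6

    intersection = inputs * targets              # [0.7, 0.0]
    union = inputs + targets                     # [1.7, 0.3]
    dice = (2.0 * intersection + eps) / (union + eps)
    print(1.0 - dice)                            # ~[0.1765, 1.0000]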
@@ -6,6 +6,7 @@ import unittest

 import mlx.core as mx
 import mlx.nn as nn
+
 import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
@@ -40,8 +41,8 @@ class TestNN(mlx_tests.MLXTestCase):
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
         targets = mx.array([0.5, 0.2, 0.9, 0.0])
-        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
-        self.assertEqual(losses, 0.0)
+        losses_none = nn.losses.l1_loss(predictions, targets, reduction="none")
+        self.assertEqual(losses_none, 0.0)

     def test_mse_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
@@ -171,6 +172,126 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_sum = mx.sum(expected_none)
         self.assertTrue(mx.allclose(losses_sum, expected_sum))

+    def test_hinge_loss(self):
+        predictions = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        targets = mx.array([[1, -1, 1, -1], [-1, 1, -1, 1]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.hinge_loss(predictions, targets, reduction="none")
+        expected_none = mx.array([[0.5, 1.5, 0.8, 1.9], [1.1, 0.7, 1.5, 0.5]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.hinge_loss(predictions, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.hinge_loss(predictions, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_huber_loss(self):
+        predictions = mx.array([1.5, 2.5, 3.5, 4.5])
+        targets = mx.array([1, 2, 3, 4])
+        delta = 1.0
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.huber_loss(predictions, targets, delta, reduction="none")
+        expected_none = mx.array([0.125, 0.125, 0.125, 0.125])  # Example expected values
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.huber_loss(predictions, targets, delta, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.huber_loss(predictions, targets, delta, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_dice_loss(self):
+        inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        targets = mx.array([[1, 0, 1, 0], [0, 1, 0, 1]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.dice_loss(inputs, targets, reduction="none")
+        expected_none = mx.array([0.658536, 0.529412])  # Example expected values
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.dice_loss(inputs, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.dice_loss(inputs, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_focal_loss(self):
+        inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        targets = mx.array([[1, 0, 1, 0], [0, 1, 0, 1]])
+        alpha = 0.25
+        gamma = 2.0
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="none")
+        expected_none = mx.array([[0.0433217, 0.0433217, 0.25751, 0.466273], [0.000263401, 0.147487, 0.0433217, 0.0433217]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_contrastive_loss(self):
+        embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
+        targets = mx.array([1, 0])
+        margin = 1.0
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="none")
+        expected_none = mx.array([0.2, 0.735425])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+    def test_cosine_similarity_loss(self):
+        embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
+        embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
+        targets = mx.array([1, -1])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="none")
+        expected_none = mx.array([0.0146555, 0.961074])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
     def test_gelu(self):
         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]

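The hard-coded `expected_none` arrays follow directly from the definitions shown earlier in this commit; for example the hinge values can be hand-checked with plain NumPy (a sketch, not part of the commit; cf. the HingeLoss reference above, which uses max(0, 1 - prediction * target)):

    import numpy as np

    predictions = np.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
    targets = np.array([[1, -1, 1, -1], [-1, 1, -1, 1]])

    print(np.maximum(0.0, 1.0 - predictions * targets))
    # [[0.5 1.5 0.8 1.9]
    #  [1.1 0.7 1.5 0.5]]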