mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Additoinal losses (#336)
* cosine similarity loss --------- Co-authored-by: Awni Hannun <awni@apple.com> * Docstring nits
This commit is contained in:
parent
432ee5650b
commit
0b8aeddac6
@ -19,4 +19,5 @@ Loss Functions
|
||||
triplet_loss
|
||||
hinge_loss
|
||||
huber_loss
|
||||
log_cosh_loss
|
||||
log_cosh_loss
|
||||
cosine_similarity_loss
|
@ -6,6 +6,17 @@ import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
logits: mx.array,
|
||||
targets: mx.array,
|
||||
@ -272,17 +283,6 @@ def triplet_loss(
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def hinge_loss(
|
||||
inputs: mx.array, targets: mx.array, reduction: str = "none"
|
||||
) -> mx.array:
|
||||
@ -372,3 +372,39 @@ def log_cosh_loss(
|
||||
loss = mx.logaddexp(errors, -errors) - math.log(2)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def cosine_similarity_loss(
|
||||
x1: mx.array,
|
||||
x2: mx.array,
|
||||
axis: int = 1,
|
||||
eps: float = 1e-8,
|
||||
reduction: str = "none",
|
||||
) -> mx.array:
|
||||
r"""
|
||||
Computes the cosine similarity between the two inputs.
|
||||
|
||||
The cosine similarity loss is given by
|
||||
|
||||
.. math::
|
||||
|
||||
\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}
|
||||
|
||||
Args:
|
||||
x1 (mx.array): The first set of inputs.
|
||||
x2 (mx.array): The second set of inputs.
|
||||
axis (int, optional): The embedding axis. Default: ``1``.
|
||||
eps (float, optional): The minimum value of the denominator used for
|
||||
numerical stability. Default: ``1e-8``.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||
|
||||
Returns:
|
||||
mx.array: The computed cosine similarity loss.
|
||||
"""
|
||||
x1_norm = mx.linalg.norm(x1, axis=axis)
|
||||
x2_norm = mx.linalg.norm(x2, axis=axis)
|
||||
|
||||
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
@ -274,6 +274,31 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
|
||||
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
|
||||
|
||||
def test_cosine_similarity_loss(self):
|
||||
embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
|
||||
embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
|
||||
|
||||
# Test with reduction 'none'
|
||||
losses_none = nn.losses.cosine_similarity_loss(
|
||||
embeddings1, embeddings2, reduction="none"
|
||||
)
|
||||
expected_none = mx.array([0.985344, 0.961074])
|
||||
self.assertTrue(mx.allclose(losses_none, expected_none))
|
||||
|
||||
# Test with reduction 'mean'
|
||||
losses_mean = nn.losses.cosine_similarity_loss(
|
||||
embeddings1, embeddings2, reduction="mean"
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.cosine_similarity_loss(
|
||||
embeddings1, embeddings2, reduction="sum"
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user