Additional losses (#336)

* cosine similarity loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>

* Docstring nits
YUN, Junwoo 2024-01-09 07:01:13 +09:00 committed by GitHub
parent 432ee5650b
commit 0b8aeddac6
3 changed files with 74 additions and 12 deletions


@@ -19,4 +19,5 @@ Loss Functions
triplet_loss
hinge_loss
huber_loss
log_cosh_loss
cosine_similarity_loss


@@ -6,6 +6,17 @@ import mlx.core as mx
from mlx.nn.layers.base import Module
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def cross_entropy(
logits: mx.array,
targets: mx.array,
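The hunk above adds a shared _reduce helper near the top of the module; the hunk below removes the old copy that lived next to triplet_loss. A minimal sketch of the helper's semantics (not part of the diff), with made-up values, purely for illustration:

import mlx.core as mx

loss = mx.array([0.2, 0.4, 0.6])
# reduction="none" returns the per-element loss array unchanged
print(mx.mean(loss))  # reduction="mean" -> 0.4
print(mx.sum(loss))   # reduction="sum"  -> 1.2
# any other reduction string raises ValueError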
@@ -272,17 +283,6 @@ def triplet_loss(
return _reduce(loss, reduction)
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
@@ -372,3 +372,39 @@ def log_cosh_loss(
loss = mx.logaddexp(errors, -errors) - math.log(2)
return _reduce(loss, reduction)
def cosine_similarity_loss(
x1: mx.array,
x2: mx.array,
axis: int = 1,
eps: float = 1e-8,
reduction: str = "none",
) -> mx.array:
r"""
Computes the cosine similarity between the two inputs.
The cosine similarity loss is given by
.. math::
\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}
Args:
x1 (mx.array): The first set of inputs.
x2 (mx.array): The second set of inputs.
axis (int, optional): The embedding axis. Default: ``1``.
eps (float, optional): The minimum value of the denominator used for
numerical stability. Default: ``1e-8``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
mx.array: The computed cosine similarity loss.
"""
x1_norm = mx.linalg.norm(x1, axis=axis)
x2_norm = mx.linalg.norm(x2, axis=axis)
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
return _reduce(loss, reduction)
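A hedged usage sketch of the new loss (not part of the diff), assuming mlx.nn is imported as nn; the embedding values are the same ones used in the test below:

import mlx.core as mx
import mlx.nn as nn

x1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
x2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
# One cosine similarity per row (reduction="none")
print(nn.losses.cosine_similarity_loss(x1, x2))
# Scalar summaries
print(nn.losses.cosine_similarity_loss(x1, x2, reduction="mean"))
print(nn.losses.cosine_similarity_loss(x1, x2, reduction="sum"))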


@@ -274,6 +274,31 @@ class TestLosses(mlx_tests.MLXTestCase):
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
def test_cosine_similarity_loss(self):
embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
# Test with reduction 'none'
losses_none = nn.losses.cosine_similarity_loss(
embeddings1, embeddings2, reduction="none"
)
expected_none = mx.array([0.985344, 0.961074])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.cosine_similarity_loss(
embeddings1, embeddings2, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.cosine_similarity_loss(
embeddings1, embeddings2, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
if __name__ == "__main__":
unittest.main()
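For reference, a short sketch (not part of the diff) showing how the expected values in the test follow directly from the definition added above:

import mlx.core as mx

x1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])
x2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])
num = mx.sum(x1 * x2, axis=1)
den = mx.maximum(mx.linalg.norm(x1, axis=1) * mx.linalg.norm(x2, axis=1), 1e-8)
print(num / den)  # expected to be close to [0.985344, 0.961074]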