Update losses.py

add cos sim
This commit is contained in:
YUN, Junwoo 2024-01-02 23:10:33 +09:00 committed by GitHub
parent 6de4c7a2a5
commit d66fec6e81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -372,3 +372,49 @@ def log_cosh_loss(
loss = mx.logaddexp(errors, -errors) - math.log(2)
return _reduce(loss, reduction)
def cosine_similarity_loss(
embeddings1: mx.array,
embeddings2: mx.array,
targets: mx.array,
eps: float = 1e-8,
margin: float = 0.0,
reduction: str = "none",
) -> mx.array:
"""
Computes the Cosine Similarity loss.
This loss function calculates the cosine of the angle between two vectors and is often used in
tasks involving embeddings, such as natural language processing or image recognition.
.. math::
\text{Cosine Similarity Loss}(e_1, e_2, y) =
\begin{cases}
1 - \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} & \text{if } y = 1 \\
\max(0, \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} - \text{margin}) & \text{if } y = -1
\end{cases}
Args:
embeddings1 (mx.array): Embeddings for the first set of samples.
embeddings2 (mx.array): Embeddings for the second set of samples.
targets (mx.array): The target values (cosine similarity between embeddings).
margin (float, optional): Margin for dissimilar pairs. Default: ``0.0``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
mx.array: The computed Cosine Similarity loss.
"""
embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
embeddings1_norm * embeddings2_norm
)
loss = mx.where(
targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
)
return _reduce(loss, reduction)