mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Update losses.py
add cos sim
This commit is contained in:
parent
6de4c7a2a5
commit
d66fec6e81
@ -372,3 +372,49 @@ def log_cosh_loss(
|
||||
loss = mx.logaddexp(errors, -errors) - math.log(2)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def cosine_similarity_loss(
|
||||
embeddings1: mx.array,
|
||||
embeddings2: mx.array,
|
||||
targets: mx.array,
|
||||
eps: float = 1e-8,
|
||||
margin: float = 0.0,
|
||||
reduction: str = "none",
|
||||
) -> mx.array:
|
||||
"""
|
||||
Computes the Cosine Similarity loss.
|
||||
|
||||
This loss function calculates the cosine of the angle between two vectors and is often used in
|
||||
tasks involving embeddings, such as natural language processing or image recognition.
|
||||
|
||||
.. math::
|
||||
|
||||
\text{Cosine Similarity Loss}(e_1, e_2, y) =
|
||||
\begin{cases}
|
||||
1 - \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} & \text{if } y = 1 \\
|
||||
\max(0, \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} - \text{margin}) & \text{if } y = -1
|
||||
\end{cases}
|
||||
|
||||
|
||||
Args:
|
||||
embeddings1 (mx.array): Embeddings for the first set of samples.
|
||||
embeddings2 (mx.array): Embeddings for the second set of samples.
|
||||
targets (mx.array): The target values (cosine similarity between embeddings).
|
||||
margin (float, optional): Margin for dissimilar pairs. Default: ``0.0``.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||
|
||||
Returns:
|
||||
mx.array: The computed Cosine Similarity loss.
|
||||
"""
|
||||
embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
|
||||
embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
|
||||
|
||||
cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
|
||||
embeddings1_norm * embeddings2_norm
|
||||
)
|
||||
loss = mx.where(
|
||||
targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
|
||||
)
|
||||
return _reduce(loss, reduction)
|
||||
|
Loading…
Reference in New Issue
Block a user