Update losses.py

add cos sim

parent 6de4c7a2a5
commit d66fec6e81
@@ -372,3 +372,49 @@ def log_cosh_loss(
     loss = mx.logaddexp(errors, -errors) - math.log(2)
 
     return _reduce(loss, reduction)
+
+
+def cosine_similarity_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    eps: float = 1e-8,
+    margin: float = 0.0,
+    reduction: str = "none",
+) -> mx.array:
"""
|
||||||
|
Computes the Cosine Similarity loss.
|
||||||
|
|
||||||
|
This loss function calculates the cosine of the angle between two vectors and is often used in
|
||||||
|
tasks involving embeddings, such as natural language processing or image recognition.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
\text{Cosine Similarity Loss}(e_1, e_2, y) =
|
||||||
|
\begin{cases}
|
||||||
|
1 - \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} & \text{if } y = 1 \\
|
||||||
|
\max(0, \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} - \text{margin}) & \text{if } y = -1
|
||||||
|
\end{cases}
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings1 (mx.array): Embeddings for the first set of samples.
|
||||||
|
embeddings2 (mx.array): Embeddings for the second set of samples.
|
||||||
|
targets (mx.array): The target values (cosine similarity between embeddings).
|
||||||
|
margin (float, optional): Margin for dissimilar pairs. Default: ``0.0``.
|
||||||
|
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||||
|
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mx.array: The computed Cosine Similarity loss.
|
||||||
|
"""
|
||||||
|
+    embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
+    embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
+
+    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
+        embeddings1_norm * embeddings2_norm
+    )
+    loss = mx.where(
+        targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
+    )
+    return _reduce(loss, reduction)
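A minimal usage sketch, assuming the new function is exposed as mlx.nn.losses.cosine_similarity_loss alongside the other losses in this file (the inputs below are illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    # Two pairs of 2-d embeddings: the first pair is identical (target 1),
    # the second is orthogonal (target -1).
    e1 = mx.array([[1.0, 0.0], [0.0, 1.0]])
    e2 = mx.array([[1.0, 0.0], [1.0, 0.0]])
    targets = mx.array([1, -1])

    loss = nn.losses.cosine_similarity_loss(e1, e2, targets, reduction="mean")
    print(loss)  # ~0.0: similarity ~1 for the first pair, ~0 for the second

With the default margin of 0.0, an orthogonal negative pair already incurs no loss; only positive similarity between dissimilar embeddings is penalized.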
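One more check on the dissimilar branch: its max(0, cos - margin) form means the margin is the similarity level below which negative pairs stop being penalized. A standalone sketch of just that branch, with illustrative values:

    import mlx.core as mx

    # Dissimilar pairs with cosine similarities 0.1 and 0.9; with a margin
    # of 0.2, only similarity in excess of the margin is penalized:
    #   max(0, 0.1 - 0.2) = 0.0, max(0, 0.9 - 0.2) = 0.7
    cos_sim = mx.array([0.1, 0.9])
    penalty = mx.maximum(0, cos_sim - 0.2)
    print(penalty)  # array([0, 0.7], dtype=float32)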