From d66fec6e81d2e6acc448636534dbc2949d56236d Mon Sep 17 00:00:00 2001
From: "YUN, Junwoo" <61632100+Jyun1998@users.noreply.github.com>
Date: Tue, 2 Jan 2024 23:10:33 +0900
Subject: [PATCH] Update losses.py

add cos sim
---
 python/mlx/nn/losses.py | 51 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 1bb99d215..8898ebd09 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -372,3 +372,54 @@ def log_cosh_loss(
     loss = mx.logaddexp(errors, -errors) - math.log(2)
 
     return _reduce(loss, reduction)
+
+
+def cosine_similarity_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    eps: float = 1e-8,
+    margin: float = 0.0,
+    reduction: str = "none",
+) -> mx.array:
+    r"""
+    Computes the cosine similarity loss.
+
+    The loss encourages similar pairs (``targets == 1``) to have a cosine
+    similarity close to ``1`` and penalizes dissimilar pairs (``targets == -1``)
+    whose similarity exceeds ``margin``. It is commonly used for learning
+    embeddings, for example in natural language processing or image retrieval.
+
+    .. math::
+
+        \text{loss}(e_1, e_2, y) =
+        \begin{cases}
+            1 - \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} & \text{if } y = 1 \\
+            \max\left(0, \frac{e_1 \cdot e_2}{\|e_1\| \|e_2\|} - \text{margin}\right) & \text{if } y = -1
+        \end{cases}
+
+    Args:
+        embeddings1 (mx.array): Embeddings for the first set of samples.
+        embeddings2 (mx.array): Embeddings for the second set of samples.
+        targets (mx.array): Labels of ``1`` for similar pairs and ``-1`` for
+            dissimilar pairs.
+        eps (float, optional): Small constant added inside the square root for
+            numerical stability. Default: ``1e-8``.
+        margin (float, optional): Margin below which dissimilar pairs incur no
+            loss. Default: ``0.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed cosine similarity loss.
+    """
+    embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
+    embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
+
+    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
+        embeddings1_norm * embeddings2_norm
+    )
+    loss = mx.where(
+        targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
+    )
+    return _reduce(loss, reduction)
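
--
A minimal usage sketch of the new loss (not part of the patch). It assumes the
patch is applied and that `cosine_similarity_loss` is importable from
`mlx.nn.losses`; the shapes, margin, and target values are illustrative:

    import mlx.core as mx
    from mlx.nn.losses import cosine_similarity_loss

    # Two batches of four 8-dimensional embeddings.
    e1 = mx.random.normal((4, 8))
    e2 = mx.random.normal((4, 8))

    # 1 marks a similar pair, -1 a dissimilar pair.
    targets = mx.array([1, -1, 1, -1])

    # Per-pair losses (default reduction="none") and a scalar mean loss.
    per_pair = cosine_similarity_loss(e1, e2, targets, margin=0.2)
    mean_loss = cosine_similarity_loss(
        e1, e2, targets, margin=0.2, reduction="mean"
    )

With `margin=0.2`, a dissimilar pair contributes zero loss once its cosine
similarity drops below 0.2, so only pairs that are "too similar" are pushed
apart.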