mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-21 12:38:10 +08:00
Additoinal losses (#336)
* cosine similarity loss --------- Co-authored-by: Awni Hannun <awni@apple.com> * Docstring nits
This commit is contained in:
@@ -6,6 +6,17 @@ import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
logits: mx.array,
|
||||
targets: mx.array,
|
||||
@@ -272,17 +283,6 @@ def triplet_loss(
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def hinge_loss(
|
||||
inputs: mx.array, targets: mx.array, reduction: str = "none"
|
||||
) -> mx.array:
|
||||
@@ -372,3 +372,39 @@ def log_cosh_loss(
|
||||
loss = mx.logaddexp(errors, -errors) - math.log(2)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def cosine_similarity_loss(
|
||||
x1: mx.array,
|
||||
x2: mx.array,
|
||||
axis: int = 1,
|
||||
eps: float = 1e-8,
|
||||
reduction: str = "none",
|
||||
) -> mx.array:
|
||||
r"""
|
||||
Computes the cosine similarity between the two inputs.
|
||||
|
||||
The cosine similarity loss is given by
|
||||
|
||||
.. math::
|
||||
|
||||
\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}
|
||||
|
||||
Args:
|
||||
x1 (mx.array): The first set of inputs.
|
||||
x2 (mx.array): The second set of inputs.
|
||||
axis (int, optional): The embedding axis. Default: ``1``.
|
||||
eps (float, optional): The minimum value of the denominator used for
|
||||
numerical stability. Default: ``1e-8``.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||
|
||||
Returns:
|
||||
mx.array: The computed cosine similarity loss.
|
||||
"""
|
||||
x1_norm = mx.linalg.norm(x1, axis=axis)
|
||||
x2_norm = mx.linalg.norm(x2, axis=axis)
|
||||
|
||||
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
Reference in New Issue
Block a user