mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
add losses
This commit is contained in:
parent f935cfb0fa
commit bc2fa2a952
@@ -205,3 +205,121 @@ def kl_div_loss(
    return _reduce(loss, reduction)


def hinge_loss(predictions: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
    """
    Computes the hinge loss between predictions and targets for binary classification tasks.

    Args:
        predictions (mx.array): The predicted values.
        targets (mx.array): The target values (-1 or 1).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        mx.array: The computed hinge loss.
    """
    loss = mx.maximum(0, 1 - targets * predictions)
    return _reduce(loss, reduction)
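
For reference, a minimal usage sketch (illustrative, not part of this commit). It assumes this file is `mlx/nn/losses.py`, so the function would be importable from `mlx.nn.losses`:

import mlx.core as mx
from mlx.nn.losses import hinge_loss  # assumed import path

predictions = mx.array([0.9, -0.4, 0.1])
targets = mx.array([1, -1, -1])  # labels must be -1 or 1
# Per-sample losses are max(0, 1 - t * p) = [0.1, 0.6, 1.1].
print(hinge_loss(predictions, targets, reduction="mean"))  # 0.6
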
def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none") -> mx.array:
    """
    Computes the Huber loss, a robust loss function for regression tasks.

    Args:
        predictions (mx.array): The predicted values.
        targets (mx.array): The target values.
        delta (float, optional): Threshold for switching between quadratic and linear losses. Default: ``1.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        mx.array: The computed Huber loss.
    """
    error = mx.abs(predictions - targets)
    is_small_error = error < delta
    squared_loss = 0.5 * mx.square(error)
    linear_loss = delta * error - 0.5 * mx.square(delta)
    loss = mx.where(is_small_error, squared_loss, linear_loss)
    return _reduce(loss, reduction)
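
A short sketch of the two Huber regimes under the default `delta` (same `mlx.nn.losses` import assumption as above):

import mlx.core as mx
from mlx.nn.losses import huber_loss  # assumed import path

predictions = mx.array([0.5, 3.0])
targets = mx.array([0.0, 0.0])
# |error| = [0.5, 3.0]; with delta = 1.0 the first term is quadratic
# (0.5 * 0.5**2 = 0.125) and the second is linear (1.0 * 3.0 - 0.5 = 2.5).
print(huber_loss(predictions, targets, delta=1.0))  # [0.125, 2.5]
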
def dice_loss(inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none") -> mx.array:
    """
    Computes the Dice loss, useful for binary segmentation tasks.

    Args:
        inputs (mx.array): Predicted probabilities for each pixel.
        targets (mx.array): The target values (binary labels for each pixel).
        epsilon (float, optional): Small constant for numerical stability. Default: ``1e-6``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        mx.array: The computed Dice loss.
    """
    intersection = mx.sum(inputs * targets)
    union = mx.sum(inputs) + mx.sum(targets)
    dice_score = (2.0 * intersection + epsilon) / (union + epsilon)
    return _reduce(1 - dice_score, reduction)
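
Note that `dice_loss` already reduces over the whole input with `mx.sum`, so it returns a scalar and the `reduction` argument has no further effect. A small sketch (same import assumption):

import mlx.core as mx
from mlx.nn.losses import dice_loss  # assumed import path

inputs = mx.array([0.9, 0.8, 0.1, 0.2])   # predicted foreground probabilities
targets = mx.array([1.0, 1.0, 0.0, 0.0])  # binary ground-truth mask
# intersection = 1.7, union = 4.0 -> dice ~ 0.85, loss ~ 0.15
print(dice_loss(inputs, targets))
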
def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "none") -> mx.array:
    """
    Computes the Focal loss, useful for handling class imbalance in binary classification tasks.

    Args:
        inputs (mx.array): Predicted probabilities for the positive class.
        targets (mx.array): The target values (binary).
        alpha (float, optional): Weighting factor for positive examples. Default: ``0.25``.
        gamma (float, optional): Modulating factor for hard examples. Default: ``2.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        mx.array: The computed Focal loss.
    """
    # p_t is the model's probability for the true class of each example.
    p_t = targets * inputs + (1 - targets) * (1 - inputs)
    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
    loss = -alpha_t * mx.power(1 - p_t, gamma) * mx.log(p_t)
    return _reduce(loss, reduction)
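
A sketch of how the modulating factor behaves (same import assumption). Note that `inputs` must lie strictly in (0, 1), since the loss takes `mx.log(p_t)`:

import mlx.core as mx
from mlx.nn.losses import focal_loss  # assumed import path

inputs = mx.array([0.95, 0.10, 0.40])  # probabilities of the positive class
targets = mx.array([1.0, 0.0, 1.0])
# p_t = [0.95, 0.90, 0.40]; easy examples (p_t near 1) are down-weighted
# by the (1 - p_t) ** gamma factor, so the hard third example dominates.
print(focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction="mean"))
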
def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
    """
    Computes the Contrastive loss, useful for learning embeddings.

    Args:
        embeddings1 (mx.array): Embeddings for the first set of samples.
        embeddings2 (mx.array): Embeddings for the second set of samples.
        targets (mx.array): The target values (binary labels indicating if pairs are similar or dissimilar).
        margin (float, optional): Margin for dissimilar pairs. Default: ``1.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        mx.array: The computed Contrastive loss.
    """
    distances = mx.sqrt(mx.sum(mx.square(embeddings1 - embeddings2), axis=1))
    loss = targets * distances + (1 - targets) * mx.maximum(0, margin - distances)
    return _reduce(loss, reduction)
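
A sketch with one similar and one dissimilar pair (same import assumption). A target of 1 pulls a pair together; 0 pushes it apart until the margin is met:

import mlx.core as mx
from mlx.nn.losses import contrastive_loss  # assumed import path

emb1 = mx.array([[1.0, 0.0], [0.0, 1.0]])
emb2 = mx.array([[1.0, 0.1], [1.0, 0.0]])
targets = mx.array([1.0, 0.0])
# distances = [0.1, sqrt(2)]; the dissimilar pair is already past the
# margin of 1.0, so only the similar pair contributes: [0.1, 0.0].
print(contrastive_loss(emb1, emb2, targets, margin=1.0))
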
def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
    """
    Computes the Cosine Similarity loss, useful for tasks where the angle between embeddings is important.

    Args:
        embeddings1 (mx.array): Embeddings for the first set of samples.
        embeddings2 (mx.array): Embeddings for the second set of samples.
        targets (mx.array): The target values (1 for pairs that should be similar, -1 for pairs that should be dissimilar).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        mx.array: The computed Cosine Similarity loss.
    """
    # Compute the L2 norms explicitly from ops already used in this module.
    norms1 = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1))
    norms2 = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1))
    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (norms1 * norms2)
    loss = 1 - cos_similarity * targets
    return _reduce(loss, reduction)
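
Finally, a sketch for the cosine loss (same import assumption), where a target of 1 rewards aligned embeddings and -1 rewards opposed ones:

import mlx.core as mx
from mlx.nn.losses import cosine_similarity_loss  # assumed import path

emb1 = mx.array([[1.0, 0.0], [1.0, 0.0]])
emb2 = mx.array([[1.0, 0.0], [-1.0, 0.0]])
targets = mx.array([1.0, -1.0])
# cos = [1.0, -1.0]; both pairs already match their targets, so the
# loss 1 - cos * t is [0.0, 0.0].
print(cosine_similarity_loss(emb1, emb2, targets))
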