diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 487182772..fd2068a3d 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -205,3 +205,121 @@ def kl_div_loss(
     return _reduce(loss, reduction)
+
+
+def hinge_loss(
+    predictions: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the hinge loss between predictions and targets for binary
+    classification tasks.
+
+    Args:
+        predictions (mx.array): The predicted values.
+        targets (mx.array): The target values (-1 or 1).
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed hinge loss.
+    """
+    loss = mx.maximum(0, 1 - targets * predictions)
+    return _reduce(loss, reduction)
+
+
+def huber_loss(
+    predictions: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: str = "none",
+) -> mx.array:
+    """
+    Computes the Huber loss, a robust loss function for regression tasks.
+
+    Args:
+        predictions (mx.array): The predicted values.
+        targets (mx.array): The target values.
+        delta (float, optional): Threshold for switching between the quadratic
+            and linear losses. Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed Huber loss.
+    """
+    error = mx.abs(predictions - targets)
+    is_small_error = error < delta
+    squared_loss = 0.5 * mx.square(error)
+    linear_loss = delta * error - 0.5 * mx.square(delta)
+    loss = mx.where(is_small_error, squared_loss, linear_loss)
+    return _reduce(loss, reduction)
+
+
+def dice_loss(
+    inputs: mx.array, targets: mx.array, epsilon: float = 1e-6, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the Dice loss, useful for binary segmentation tasks.
+
+    Args:
+        inputs (mx.array): Predicted probabilities for each pixel.
+        targets (mx.array): The target values (binary labels for each pixel).
+        epsilon (float, optional): Small constant for numerical stability.
+            Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed Dice loss.
+    """
+    intersection = mx.sum(inputs * targets)
+    union = mx.sum(inputs) + mx.sum(targets)
+    dice_score = (2.0 * intersection + epsilon) / (union + epsilon)
+    return _reduce(1 - dice_score, reduction)
+
+
+def focal_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    alpha: float = 0.25,
+    gamma: float = 2.0,
+    reduction: str = "none",
+) -> mx.array:
+    """
+    Computes the Focal loss, useful for handling class imbalance in binary
+    classification tasks.
+
+    Args:
+        inputs (mx.array): Predicted probabilities for the positive class
+            (strictly between 0 and 1 to avoid ``log(0)``).
+        targets (mx.array): The target values (binary).
+        alpha (float, optional): Weighting factor for positive examples.
+            Default: ``0.25``.
+        gamma (float, optional): Modulating factor for hard examples.
+            Default: ``2.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed Focal loss.
+    """
+    p_t = targets * inputs + (1 - targets) * (1 - inputs)
+    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
+    loss = -alpha_t * (1 - p_t) ** gamma * mx.log(p_t)
+    return _reduce(loss, reduction)
+
+
+def contrastive_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    margin: float = 1.0,
+    reduction: str = "none",
+) -> mx.array:
+    """
+    Computes the Contrastive loss, useful for learning embeddings.
+
+    Args:
+        embeddings1 (mx.array): Embeddings for the first set of samples.
+        embeddings2 (mx.array): Embeddings for the second set of samples.
+        targets (mx.array): Binary labels indicating whether each pair is
+            similar (1) or dissimilar (0).
+        margin (float, optional): Margin below which dissimilar pairs are
+            penalized. Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed Contrastive loss.
+    """
+    distances = mx.sqrt(mx.sum(mx.square(embeddings1 - embeddings2), axis=1))
+    loss = targets * distances + (1 - targets) * mx.maximum(0, margin - distances)
+    return _reduce(loss, reduction)
+
+
+def cosine_similarity_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    reduction: str = "none",
+) -> mx.array:
+    """
+    Computes the Cosine Similarity loss, useful for tasks where the angle
+    between embeddings is important.
+
+    Args:
+        embeddings1 (mx.array): Embeddings for the first set of samples.
+        embeddings2 (mx.array): Embeddings for the second set of samples.
+        targets (mx.array): The target values (1 for similar pairs, -1 for
+            dissimilar pairs).
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed Cosine Similarity loss.
+    """
+    norms1 = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1))
+    norms2 = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1))
+    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (norms1 * norms2)
+    loss = 1 - cos_similarity * targets
+    return _reduce(loss, reduction)
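
A minimal usage sketch of the element-wise losses added above, for reviewers who want to try the patch. This is illustrative only: it assumes the patch is applied and the module imports as mlx.nn.losses; all sample arrays and values are made up, not part of the change.

    # Illustrative check of the element-wise losses (made-up inputs).
    import mlx.core as mx
    import mlx.nn.losses as losses

    # Hinge loss expects targets in {-1, 1}.
    predictions = mx.array([0.8, -0.5, 1.2, -1.0])
    targets = mx.array([1, -1, 1, 1])
    print(losses.hinge_loss(predictions, targets, reduction="mean"))

    # Huber loss is quadratic for |error| < delta and linear beyond it.
    y_pred = mx.array([1.5, 0.0, 3.0])
    y_true = mx.array([1.0, 0.5, -1.0])
    print(losses.huber_loss(y_pred, y_true, delta=1.0, reduction="mean"))

    # Dice and focal losses expect probabilities in (0, 1).
    probs = mx.array([0.9, 0.2, 0.8, 0.4])
    labels = mx.array([1, 0, 1, 0])
    print(losses.dice_loss(probs, labels))
    print(losses.focal_loss(probs, labels, alpha=0.25, gamma=2.0, reduction="mean"))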
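
And a similar hedged sketch for the two pairwise embedding losses; the embeddings and labels here are likewise invented for illustration.

    # Illustrative check of the pairwise embedding losses (made-up inputs).
    import mlx.core as mx
    import mlx.nn.losses as losses

    emb1 = mx.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    emb2 = mx.array([[1.0, 0.1], [1.0, 0.0], [-1.0, -1.0]])

    # Contrastive loss: 1 marks a similar pair, 0 a dissimilar pair.
    pair_labels = mx.array([1, 0, 0])
    print(losses.contrastive_loss(emb1, emb2, pair_labels, margin=1.0, reduction="mean"))

    # Cosine similarity loss: 1 pulls a pair together, -1 pushes it apart.
    sign_labels = mx.array([1, 1, -1])
    print(losses.cosine_similarity_loss(emb1, emb2, sign_labels, reduction="mean"))

Note that this contrastive_loss penalizes similar pairs by raw Euclidean distance rather than the squared distance of the classic Hadsell et al. formulation; that is a common variant, but worth flagging in review.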