From 80c4630d264b04b8910f00ec8186bbcda71ea1e4 Mon Sep 17 00:00:00 2001
From: Jyun1998
Date: Tue, 2 Jan 2024 02:55:36 +0900
Subject: [PATCH] precommit

---
 python/mlx/nn/losses.py | 57 ++++++++++++++++++++++++++++++++---------
 1 file changed, 45 insertions(+), 12 deletions(-)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index d395a0f04..f902af8af 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -206,7 +206,9 @@ def kl_div_loss(
     return _reduce(loss, reduction)
 
 
-def hinge_loss(predictions: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
+def hinge_loss(
+    predictions: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
     """
     Computes the hinge loss between predictions and targets for binary classification tasks.
 
@@ -223,7 +225,12 @@ def hinge_loss(predictions: mx.array, targets: mx.array, reduction: str = "none"
     return _reduce(loss, reduction)
 
 
-def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none") -> mx.array:
+def huber_loss(
+    predictions: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Huber loss, a robust loss function for regression tasks.
 
@@ -240,12 +247,14 @@ def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, red
     """
     error = mx.abs(predictions - targets)
     is_small_error = error < delta
     squared_loss = 0.5 * mx.square(error)
-    linear_loss = delta * error - 0.5 * (delta ** 2)
+    linear_loss = delta * error - 0.5 * (delta**2)
     loss = mx.where(is_small_error, squared_loss, linear_loss)
     return _reduce(loss, reduction)
 
 
-def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction: str = "none") -> mx.array:
+def dice_loss(
+    inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction: str = "none"
+) -> mx.array:
     """
     Computes the Dice loss, useful for binary segmentation tasks.
 
@@ -261,17 +270,24 @@ def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction:
     """
     intersection = mx.sum(inputs * targets, axis=1)  # Sum over the feature dimension
     union = mx.sum(inputs, axis=1) + mx.sum(targets, axis=1)
-    dice_score = (2. * intersection + eps) / (union + eps)
+    dice_score = (2.0 * intersection + eps) / (union + eps)
     loss = 1 - dice_score
     return _reduce(loss, reduction)
 
-def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "none") -> mx.array:
+
+def focal_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    alpha: float = 0.25,
+    gamma: float = 2.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Focal loss, useful for handling class imbalance in binary classification tasks.
 
     Args:
         inputs (mx.array): Predicted probabilities for the positive class.
-        targets (mx.array): The target values (binary). 
+        targets (mx.array): The target values (binary).
         alpha (float, optional): Weighting factor for positive examples. Default: ``0.25``.
         gamma (float, optional): Modulating factor for hard examples. Default: ``2.0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
@@ -286,7 +302,13 @@ def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma:
     return _reduce(loss, reduction)
 
 
-def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
+def contrastive_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    margin: float = 1.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Contrastive loss, useful for learning embeddings.
 
@@ -306,7 +328,14 @@ def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.a
     return _reduce(loss, reduction)
 
 
-def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, eps: float=1e-8, margin: float=0.0, reduction: str = "none") -> mx.array:
+def cosine_similarity_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    eps: float = 1e-8,
+    margin: float = 0.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Cosine Similarity loss, useful for tasks where the angle between embeddings is important.
 
@@ -324,6 +353,10 @@ def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets
 
     embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
     embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
-    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm)
-    loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin))
-    return _reduce(loss, reduction)
\ No newline at end of file
+    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
+        embeddings1_norm * embeddings2_norm
+    )
+    loss = mx.where(
+        targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
+    )
+    return _reduce(loss, reduction)
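Usage note (illustrative, not part of the patch): the commit only re-wraps signatures and
numeric literals, so behavior is unchanged. Below is a minimal sketch of how the touched
losses would be called, assuming this branch is installed so the functions are importable
from mlx.nn.losses (the file the patch edits), and that the existing _reduce helper accepts
"mean" as it does for the other losses in this module; the sample arrays are made up.

# Minimal smoke test for the losses reformatted by this patch.
import mlx.core as mx
from mlx.nn.losses import (
    contrastive_loss,
    cosine_similarity_loss,
    dice_loss,
    focal_loss,
    hinge_loss,
    huber_loss,
)

# Margin/regression losses on 1-D arrays; hinge expects +/-1 targets.
predictions = mx.array([0.7, -1.2, 0.3, 1.5])
targets = mx.array([1.0, -1.0, -1.0, 1.0])
print(hinge_loss(predictions, targets, reduction="mean"))
print(huber_loss(predictions, targets, delta=1.0, reduction="mean"))

# Probability-based losses; dice_loss sums over axis=1, so inputs are
# (batch, features) probabilities with binary targets.
probs = mx.array([[0.9, 0.1, 0.8], [0.2, 0.7, 0.4]])
binary = mx.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
print(dice_loss(probs, binary, reduction="mean"))
print(focal_loss(probs, binary, alpha=0.25, gamma=2.0, reduction="mean"))

# Embedding losses on (batch, dim) pairs; targets == 1 marks similar pairs.
e1 = mx.random.normal((4, 8))
e2 = mx.random.normal((4, 8))
pair_labels = mx.array([1, 0, 1, 0])
print(contrastive_loss(e1, e2, pair_labels, margin=1.0, reduction="mean"))
print(cosine_similarity_loss(e1, e2, pair_labels, reduction="mean"))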