mirror of
https://github.com/ml-explore/mlx.git
precommit

This commit is contained in:
parent 19bc6a391a
commit 80c4630d26
@@ -206,7 +206,9 @@ def kl_div_loss(
     return _reduce(loss, reduction)


-def hinge_loss(predictions: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
+def hinge_loss(
+    predictions: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
     """
     Computes the hinge loss between predictions and targets for binary classification tasks.

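Aside (not part of the commit): this hunk only rewraps the hinge_loss signature; the body is outside the diff. As a reading aid, here is a minimal standalone sketch of the standard hinge formulation max(0, 1 - y * p) with targets in {-1, 1}; the formulation is assumed, not taken from this file, and _reduce is replaced by a plain mean.

import mlx.core as mx

# Standard hinge loss: max(0, 1 - y * p), targets y in {-1, 1}.
# Assumed formulation, not the body from this diff.
predictions = mx.array([2.0, -1.5, 0.3])
targets = mx.array([1.0, -1.0, -1.0])
loss = mx.maximum(0.0, 1.0 - targets * predictions)
print(mx.mean(loss))  # only the misclassified third sample contributes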
@@ -223,7 +225,12 @@ def hinge_loss(predictions: mx.array, targets: mx.array, reduction: str = "none"
     return _reduce(loss, reduction)


-def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none") -> mx.array:
+def huber_loss(
+    predictions: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Huber loss, a robust loss function for regression tasks.

@@ -240,12 +247,14 @@ def huber_loss(predictions: mx.array, targets: mx.array, delta: float = 1.0, red
     error = mx.abs(predictions - targets)
     is_small_error = error < delta
     squared_loss = 0.5 * mx.square(error)
-    linear_loss = delta * error - 0.5 * (delta ** 2)
+    linear_loss = delta * error - 0.5 * (delta**2)
     loss = mx.where(is_small_error, squared_loss, linear_loss)
     return _reduce(loss, reduction)


-def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction: str = "none") -> mx.array:
+def dice_loss(
+    inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction: str = "none"
+) -> mx.array:
     """
     Computes the Dice loss, useful for binary segmentation tasks.

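Aside (not part of the commit): the huber_loss body above is the standard piecewise definition, quadratic for errors below delta and linear beyond. A small self-contained check with made-up values:

import mlx.core as mx

predictions = mx.array([1.5, 0.0, 3.0])
targets = mx.array([1.0, 0.0, 0.0])
delta = 1.0

error = mx.abs(predictions - targets)           # [0.5, 0.0, 3.0]
is_small_error = error < delta
squared_loss = 0.5 * mx.square(error)           # quadratic branch
linear_loss = delta * error - 0.5 * (delta**2)  # linear branch
loss = mx.where(is_small_error, squared_loss, linear_loss)
print(loss)  # [0.125, 0.0, 2.5]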
@@ -261,11 +270,18 @@ def dice_loss(inputs: mx.array, targets: mx.array, eps: float = 1e-6, reduction:
     """
     intersection = mx.sum(inputs * targets, axis=1)  # Sum over the feature dimension
     union = mx.sum(inputs, axis=1) + mx.sum(targets, axis=1)
-    dice_score = (2. * intersection + eps) / (union + eps)
+    dice_score = (2.0 * intersection + eps) / (union + eps)
     loss = 1 - dice_score
     return _reduce(loss, reduction)

-def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "none") -> mx.array:
+
+def focal_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    alpha: float = 0.25,
+    gamma: float = 2.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Focal loss, useful for handling class imbalance in binary classification tasks.

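Aside (not part of the commit): the dice_loss body above computes 1 minus the Dice coefficient per sample, with eps guarding against division by zero on empty masks. A check on two tiny flattened masks (example values are mine):

import mlx.core as mx

inputs = mx.array([[1.0, 1.0, 0.0, 0.0]])   # predicted binary mask, (batch, features)
targets = mx.array([[1.0, 0.0, 0.0, 0.0]])  # ground-truth mask
eps = 1e-6

intersection = mx.sum(inputs * targets, axis=1)           # overlap = 1
union = mx.sum(inputs, axis=1) + mx.sum(targets, axis=1)  # 2 + 1 = 3
dice_score = (2.0 * intersection + eps) / (union + eps)   # 2/3
print(1 - dice_score)  # ~[0.333]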
@@ -286,7 +302,13 @@ def focal_loss(inputs: mx.array, targets: mx.array, alpha: float = 0.25, gamma:
     return _reduce(loss, reduction)


-def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, margin: float = 1.0, reduction: str = "none") -> mx.array:
+def contrastive_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    margin: float = 1.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Contrastive loss, useful for learning embeddings.

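Aside (not part of the commit): the focal_loss body is not visible in this diff. Below is a sketch of the standard binary focal loss of Lin et al., which the alpha and gamma defaults in the signature suggest; it is labeled as an assumption, not the file's code.

import mlx.core as mx

# Assumed standard formulation: -alpha * (1 - p_t)**gamma * log(p_t);
# the actual body of focal_loss is outside this hunk.
inputs = mx.array([0.9, 0.6])   # predicted probabilities for the positive class
targets = mx.array([1.0, 1.0])
alpha, gamma = 0.25, 2.0

p_t = mx.where(targets == 1, inputs, 1 - inputs)
loss = -alpha * ((1 - p_t) ** gamma) * mx.log(p_t)
print(loss)  # the confident prediction contributes far less than the uncertain one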
@@ -306,7 +328,14 @@ def contrastive_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.a
     return _reduce(loss, reduction)


-def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets: mx.array, eps: float=1e-8, margin: float=0.0, reduction: str = "none") -> mx.array:
+def cosine_similarity_loss(
+    embeddings1: mx.array,
+    embeddings2: mx.array,
+    targets: mx.array,
+    eps: float = 1e-8,
+    margin: float = 0.0,
+    reduction: str = "none",
+) -> mx.array:
     """
     Computes the Cosine Similarity loss, useful for tasks where the angle between embeddings is important.

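Aside (not part of the commit): the contrastive_loss body is likewise outside the hunk. A sketch of the classic Hadsell-style formulation, which pulls similar pairs (targets == 1) together and pushes dissimilar ones apart up to margin; this is an assumption, not the file's code.

import mlx.core as mx

embeddings1 = mx.array([[0.0, 1.0], [1.0, 0.0]])
embeddings2 = mx.array([[0.0, 1.0], [0.5, 0.5]])
targets = mx.array([1.0, 0.0])  # 1 = similar pair, 0 = dissimilar
margin = 1.0

# Euclidean distance per pair, then the two-branch penalty.
distances = mx.sqrt(mx.sum(mx.square(embeddings1 - embeddings2), axis=1))
loss = targets * mx.square(distances) + (1 - targets) * mx.square(
    mx.maximum(0.0, margin - distances)
)
print(loss)  # [0.0, ~0.086]: the dissimilar pair sits inside the margin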
@@ -324,6 +353,10 @@ def cosine_similarity_loss(embeddings1: mx.array, embeddings2: mx.array, targets
     embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
     embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)

-    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (embeddings1_norm * embeddings2_norm)
-    loss = mx.where(targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin))
+    cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
+        embeddings1_norm * embeddings2_norm
+    )
+    loss = mx.where(
+        targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
+    )
     return _reduce(loss, reduction)
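Aside (not part of the commit): this last hunk shows the full cosine similarity computation, so the following check simply replays it on two hand-picked pairs (values are mine, with _reduce dropped):

import mlx.core as mx

embeddings1 = mx.array([[1.0, 0.0], [1.0, 0.0]])
embeddings2 = mx.array([[1.0, 0.0], [0.0, 1.0]])
targets = mx.array([1, 1])  # both pairs labeled "similar"
eps, margin = 1e-8, 0.0

embeddings1_norm = mx.sqrt(mx.sum(mx.square(embeddings1), axis=1) + eps)
embeddings2_norm = mx.sqrt(mx.sum(mx.square(embeddings2), axis=1) + eps)
cos_similarity = mx.sum(embeddings1 * embeddings2, axis=1) / (
    embeddings1_norm * embeddings2_norm
)
loss = mx.where(
    targets == 1, 1 - cos_similarity, mx.maximum(0, cos_similarity - margin)
)
print(loss)  # [~0.0, ~1.0]: aligned pair costs nothing, orthogonal pair costs 1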