mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:38:07 +08:00
rearrange
This commit is contained in:
parent
e7114f4b91
commit
50aab87168
@ -271,17 +271,6 @@ def triplet_loss(
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def hinge_loss(
|
||||
inputs: mx.array, targets: mx.array, reduction: str = "none"
|
||||
) -> mx.array:
|
||||
@ -543,3 +532,14 @@ def tversky_loss(
|
||||
loss = 1 - tversky_index
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
Loading…
Reference in New Issue
Block a user