mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:38:07 +08:00
rearrange
This commit is contained in:
parent
e7114f4b91
commit
50aab87168
@ -271,17 +271,6 @@ def triplet_loss(
|
|||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
|
||||||
if reduction == "mean":
|
|
||||||
return mx.mean(loss)
|
|
||||||
elif reduction == "sum":
|
|
||||||
return mx.sum(loss)
|
|
||||||
elif reduction == "none":
|
|
||||||
return loss
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
|
||||||
|
|
||||||
|
|
||||||
def hinge_loss(
|
def hinge_loss(
|
||||||
inputs: mx.array, targets: mx.array, reduction: str = "none"
|
inputs: mx.array, targets: mx.array, reduction: str = "none"
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
@ -543,3 +532,14 @@ def tversky_loss(
|
|||||||
loss = 1 - tversky_index
|
loss = 1 - tversky_index
|
||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
|
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||||
|
if reduction == "mean":
|
||||||
|
return mx.mean(loss)
|
||||||
|
elif reduction == "sum":
|
||||||
|
return mx.sum(loss)
|
||||||
|
elif reduction == "none":
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||||
|
Loading…
Reference in New Issue
Block a user