rearrange

This commit is contained in:
NripeshN 2023-12-31 01:43:07 +05:30
parent e7114f4b91
commit 50aab87168

View File

@ -271,17 +271,6 @@ def triplet_loss(
return _reduce(loss, reduction) return _reduce(loss, reduction)
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def hinge_loss( def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none" inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array: ) -> mx.array:
@ -543,3 +532,14 @@ def tversky_loss(
loss = 1 - tversky_index loss = 1 - tversky_index
return _reduce(loss, reduction) return _reduce(loss, reduction)
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")