change class order

This commit is contained in:
junwoo-yun 2023-12-31 23:20:34 +09:00
parent fb675de30d
commit f935cfb0fa

View File

@ -14,6 +14,17 @@ def _make_loss_module(f):
return decorator
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def cross_entropy(
logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
) -> mx.array:
@ -194,13 +205,3 @@ def kl_div_loss(
return _reduce(loss, reduction)
def _reduce(loss: mx.array, reduction: str = "none"):
if reduction == "mean":
return mx.mean(loss)
elif reduction == "sum":
return mx.sum(loss)
elif reduction == "none":
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")