mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
change class order
This commit is contained in:
parent
fb675de30d
commit
f935cfb0fa
@ -14,6 +14,17 @@ def _make_loss_module(f):
|
||||
return decorator
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
|
||||
) -> mx.array:
|
||||
@ -194,13 +205,3 @@ def kl_div_loss(
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
Loading…
Reference in New Issue
Block a user