mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 21:16:47 +08:00
change class order
This commit is contained in:
parent
fb675de30d
commit
f935cfb0fa
@ -14,6 +14,17 @@ def _make_loss_module(f):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _reduce(loss: mx.array, reduction: str = "none"):
|
||||||
|
if reduction == "mean":
|
||||||
|
return mx.mean(loss)
|
||||||
|
elif reduction == "sum":
|
||||||
|
return mx.sum(loss)
|
||||||
|
elif reduction == "none":
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy(
|
def cross_entropy(
|
||||||
logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
|
logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
@ -194,13 +205,3 @@ def kl_div_loss(
|
|||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
def _reduce(loss: mx.array, reduction: str = "none"):
|
|
||||||
if reduction == "mean":
|
|
||||||
return mx.mean(loss)
|
|
||||||
elif reduction == "sum":
|
|
||||||
return mx.sum(loss)
|
|
||||||
elif reduction == "none":
|
|
||||||
return loss
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
|
||||||
|
Loading…
Reference in New Issue
Block a user