diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index b495bc61a..fe8b5e4b9 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -271,17 +271,6 @@ def triplet_loss( return _reduce(loss, reduction) -def _reduce(loss: mx.array, reduction: str = "none"): - if reduction == "mean": - return mx.mean(loss) - elif reduction == "sum": - return mx.sum(loss) - elif reduction == "none": - return loss - else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") - - def hinge_loss( inputs: mx.array, targets: mx.array, reduction: str = "none" ) -> mx.array: @@ -543,3 +532,14 @@ def tversky_loss( loss = 1 - tversky_index return _reduce(loss, reduction) + + +def _reduce(loss: mx.array, reduction: str = "none"): + if reduction == "mean": + return mx.mean(loss) + elif reduction == "sum": + return mx.sum(loss) + elif reduction == "none": + return loss + else: + raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")