From f935cfb0fa87b4e3cd33b7e956380168f4cf834b Mon Sep 17 00:00:00 2001 From: junwoo-yun Date: Sun, 31 Dec 2023 23:20:34 +0900 Subject: [PATCH] change class order --- python/mlx/nn/losses.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index b9d35d9b5..487182772 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -14,6 +14,17 @@ def _make_loss_module(f): return decorator +def _reduce(loss: mx.array, reduction: str = "none"): + if reduction == "mean": + return mx.mean(loss) + elif reduction == "sum": + return mx.sum(loss) + elif reduction == "none": + return loss + else: + raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") + + def cross_entropy( logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none" ) -> mx.array: @@ -194,13 +205,3 @@ def kl_div_loss( return _reduce(loss, reduction) - -def _reduce(loss: mx.array, reduction: str = "none"): - if reduction == "mean": - return mx.mean(loss) - elif reduction == "sum": - return mx.sum(loss) - elif reduction == "none": - return loss - else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")