diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index ebf05d8ff..bccf45b16 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import math -from typing import Literal, Optional +from typing import Literal, Optional, get_args import mlx.core as mx @@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"] def _reduce(loss: mx.array, reduction: Reduction = "none"): + if reduction not in get_args(Reduction): + raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.") + if reduction == "mean": return mx.mean(loss) elif reduction == "sum": return mx.sum(loss) elif reduction == "none": return loss - else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") def cross_entropy(