From d699cc1330c2e9827326abb8f6a77f50feae02a7 Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Sat, 8 Mar 2025 09:23:04 +0800 Subject: [PATCH] Fix unreachable warning (#1939) * Fix unreachable warning * Update error message --- python/mlx/nn/losses.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index ebf05d8ff..bccf45b16 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import math -from typing import Literal, Optional +from typing import Literal, Optional, get_args import mlx.core as mx @@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"] def _reduce(loss: mx.array, reduction: Reduction = "none"): + if reduction not in get_args(Reduction): + raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.") + if reduction == "mean": return mx.mean(loss) elif reduction == "sum": return mx.sum(loss) elif reduction == "none": return loss - else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") def cross_entropy(