Fix unreachable warning (#1939)

* Fix unreachable warning

* Update error message
This commit is contained in:
Chunyang Wen 2025-03-08 09:23:04 +08:00 committed by GitHub
parent c4230747a1
commit d699cc1330
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import math import math
from typing import Literal, Optional from typing import Literal, Optional, get_args
import mlx.core as mx import mlx.core as mx
@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"]
def _reduce(loss: mx.array, reduction: Reduction = "none"): def _reduce(loss: mx.array, reduction: Reduction = "none"):
if reduction not in get_args(Reduction):
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
if reduction == "mean": if reduction == "mean":
return mx.mean(loss) return mx.mean(loss)
elif reduction == "sum": elif reduction == "sum":
return mx.sum(loss) return mx.sum(loss)
elif reduction == "none": elif reduction == "none":
return loss return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def cross_entropy( def cross_entropy(