mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix unreachable warning (#1939)
* Fix unreachable warning * Update error message
This commit is contained in:
parent
c4230747a1
commit
d699cc1330
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"]
|
||||
|
||||
|
||||
def _reduce(loss: mx.array, reduction: Reduction = "none"):
|
||||
if reduction not in get_args(Reduction):
|
||||
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
|
||||
|
||||
if reduction == "mean":
|
||||
return mx.mean(loss)
|
||||
elif reduction == "sum":
|
||||
return mx.sum(loss)
|
||||
elif reduction == "none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
|
Loading…
Reference in New Issue
Block a user