mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix unreachable warning (#1939)
* Fix unreachable warning * Update error message
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import math | ||||
| from typing import Literal, Optional | ||||
| from typing import Literal, Optional, get_args | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| @@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"] | ||||
|  | ||||
|  | ||||
| def _reduce(loss: mx.array, reduction: Reduction = "none"): | ||||
|     if reduction not in get_args(Reduction): | ||||
|         raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.") | ||||
|  | ||||
|     if reduction == "mean": | ||||
|         return mx.mean(loss) | ||||
|     elif reduction == "sum": | ||||
|         return mx.sum(loss) | ||||
|     elif reduction == "none": | ||||
|         return loss | ||||
|     else: | ||||
|         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") | ||||
|  | ||||
|  | ||||
| def cross_entropy( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Chunyang Wen
					Chunyang Wen