From d699cc1330c2e9827326abb8f6a77f50feae02a7 Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Sat, 8 Mar 2025 09:23:04 +0800 Subject: [PATCH 1/5] Fix unreachable warning (#1939) * Fix unreachable warning * Update error message --- python/mlx/nn/losses.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index ebf05d8ff..bccf45b16 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import math -from typing import Literal, Optional +from typing import Literal, Optional, get_args import mlx.core as mx @@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"] def _reduce(loss: mx.array, reduction: Reduction = "none"): + if reduction not in get_args(Reduction): + raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.") + if reduction == "mean": return mx.mean(loss) elif reduction == "sum": return mx.sum(loss) elif reduction == "none": return loss - else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") def cross_entropy( From 5db90ce8223eab7f09dde21eca4120c2dec0ff21 Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Sun, 9 Mar 2025 06:50:39 +0800 Subject: [PATCH 2/5] Fix obsured warning (#1944) --- python/mlx/nn/layers/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index f141cfc0f..f24bd1806 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -598,9 +598,7 @@ class Module(dict): parameters to the new dtype. """ if predicate is None: - - def predicate(_): - return True + predicate = lambda _: True self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) From d14c9fe7ea6e7390513ad955d595d3c20fc9ab21 Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Sun, 9 Mar 2025 06:51:04 +0800 Subject: [PATCH 3/5] Add file info when raising errors in save (#1943) --- python/mlx/nn/layers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index f24bd1806..3a696df71 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -210,7 +210,7 @@ class Module(dict): mx.save_safetensors(file, params_dict) else: raise ValueError( - "Unsupported file extension. Use '.npz' or '.safetensors'." + f"Unsupported file extension for {file}. Use '.npz' or '.safetensors'." ) @staticmethod From 048805ad2c7d4c325ef99e40a2fd6bea9464db5f Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Mon, 10 Mar 2025 21:05:26 +0800 Subject: [PATCH 4/5] Remove unused modules (#1949) --- benchmarks/python/gather_bench.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/python/gather_bench.py b/benchmarks/python/gather_bench.py index e000841d2..ae6fb8f5f 100644 --- a/benchmarks/python/gather_bench.py +++ b/benchmarks/python/gather_bench.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. import argparse -from time import time import mlx.core as mx import torch From cffceda6ee6d5dcd91bae28741fd858699dbe67d Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Mon, 10 Mar 2025 21:05:36 +0800 Subject: [PATCH 5/5] Add type hint for _extra_repr (#1948) --- python/mlx/nn/layers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 3a696df71..b35c58478 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -81,7 +81,7 @@ class Module(dict): """ return self - def _extra_repr(self): + def _extra_repr(self) -> str: return "" def __repr__(self):