diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index f141cfc0f..f24bd1806 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -598,9 +598,7 @@ class Module(dict): parameters to the new dtype. """ if predicate is None: - - def predicate(_): - return True + predicate = lambda _: True self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)