diff --git a/benchmarks/python/gather_bench.py b/benchmarks/python/gather_bench.py index e000841d2..ae6fb8f5f 100644 --- a/benchmarks/python/gather_bench.py +++ b/benchmarks/python/gather_bench.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. import argparse -from time import time import mlx.core as mx import torch diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index f141cfc0f..b35c58478 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -81,7 +81,7 @@ class Module(dict): """ return self - def _extra_repr(self): + def _extra_repr(self) -> str: return "" def __repr__(self): @@ -210,7 +210,7 @@ class Module(dict): mx.save_safetensors(file, params_dict) else: raise ValueError( - "Unsupported file extension. Use '.npz' or '.safetensors'." + f"Unsupported file extension for {file}. Use '.npz' or '.safetensors'." ) @staticmethod @@ -598,9 +598,7 @@ class Module(dict): parameters to the new dtype. """ if predicate is None: - - def predicate(_): - return True + predicate = lambda _: True self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index ebf05d8ff..bccf45b16 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -1,7 +1,7 @@ # Copyright © 2023 Apple Inc. import math -from typing import Literal, Optional +from typing import Literal, Optional, get_args import mlx.core as mx @@ -9,14 +9,15 @@ Reduction = Literal["none", "mean", "sum"] def _reduce(loss: mx.array, reduction: Reduction = "none"): + if reduction not in get_args(Reduction): + raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.") + if reduction == "mean": return mx.mean(loss) elif reduction == "sum": return mx.sum(loss) elif reduction == "none": return loss - else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") def cross_entropy(