diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index 86a96447..f268913b 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -1,27 +1,23 @@ # Copyright © 2023-2024 Apple Inc. import argparse -from enum import Enum -from .utils import convert, mixed_2_6, mixed_3_6 +from . import utils +from .utils import convert - -class MixedQuants(Enum): - mixed_3_6 = "mixed_3_6" - mixed_2_6 = "mixed_2_6" - - @classmethod - def recipe_names(cls): - return [member.name for member in cls] +QUANT_RECIPES = [ + "mixed_2_6", + "mixed_3_6", +] def quant_args(arg): - try: - return MixedQuants[arg].value - except KeyError: + if arg not in QUANT_RECIPES: raise argparse.ArgumentTypeError( - f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}" + f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}" ) + else: + return getattr(utils, arg) def configure_parser() -> argparse.ArgumentParser: @@ -50,7 +46,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--quant-predicate", - help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}", + help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}", type=quant_args, required=False, )