diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index ce4d1854..f5418bbc 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -7,13 +7,9 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .models import llama, mixtral, phi2
 from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
-from .utils import generate, load
-
-SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
-
+from .utils import generate, load, LORA_SUPPORTED_MODELS
 
 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@@ -166,10 +162,10 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
 
-    if model.__class__ not in SUPPORTED_MODELS:
+    if model.__class__ not in LORA_SUPPORTED_MODELS:
         raise ValueError(
             f"Model {model.__class__} not supported. "
-            f"Supported models: { SUPPORTED_MODELS}"
+            f"Supported models: {LORA_SUPPORTED_MODELS}"
         )
 
     # Freeze all layers other than LORA linears
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 726342e0..6fa23d76 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -25,6 +25,9 @@ MODEL_MAPPING = {
     "qwen": qwen,
     "plamo": plamo,
 }
+LORA_SUPPORTED_MODELS = [
+    llama.Model, mixtral.Model, phi2.Model, stablelm_epoch.Model
+]
 MAX_FILE_SIZE_GB = 5
 
 linear_class_predicate = (
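
Usage sketch (not part of the diff): after this change lora.py consumes the single list defined in utils.py, so enabling LoRA for another architecture only requires appending its Model class to LORA_SUPPORTED_MODELS next to MODEL_MAPPING. The snippet below mirrors the post-change guard; it assumes the mlx_lm package is importable under that name, and the model path is a placeholder.

    from mlx_lm.utils import load, LORA_SUPPORTED_MODELS

    # Placeholder path; any MLX-format model directory or Hub repo works here.
    model, tokenizer = load("path/to/mlx/model")

    # Same check lora.py now performs against the shared list in utils.py.
    if model.__class__ not in LORA_SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model.__class__} not supported. "
            f"Supported models: {LORA_SUPPORTED_MODELS}"
        )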