LoRA: Remove unnecessary model type judgments (#388)

* LoRA: Remove unnecessary model type judgments

1. Supported models are already checked by the load_model function in utils, so there is no need to repeat the check in lora (see the sketch below)
2. The list in lora is not kept in sync with the one in utils

* LoRA: add LoRA supported models in mlx_lm utils
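As a rough, standalone sketch of the layout this change moves toward (the toy classes and the load_model helper below are illustrative stand-ins, not the real mlx_lm API): the LoRA-capable model list is defined once, next to the loading code, and the fine-tuning script imports it instead of keeping its own copy.

    # Illustrative only: toy stand-ins for the mlx_lm layout, not the real API.

    # "utils" side: one list of LoRA-capable model classes, kept next to the loader.
    class LlamaModel: ...
    class Phi2Model: ...
    class BertModel: ...  # loadable in this toy example, but not LoRA-tunable

    LORA_SUPPORTED_MODELS = [LlamaModel, Phi2Model]

    def load_model(name):
        # In mlx_lm, load() also resolves and loads weights; here we just pick a class.
        return {"llama": LlamaModel, "phi2": Phi2Model, "bert": BertModel}[name]()

    # "lora" side: import the shared list instead of maintaining a second copy.
    model = load_model("llama")
    if model.__class__ not in LORA_SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model.__class__} not supported. "
            f"Supported models: {LORA_SUPPORTED_MODELS}"
        )
    print("ok: model type can be LoRA fine-tuned")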
Madroid Ma 2024-02-01 03:55:27 +08:00 committed by GitHub
parent 0a49ba0697
commit ba3a9355d1
2 changed files with 6 additions and 7 deletions

mlx_lm/lora.py

@@ -7,13 +7,9 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .models import llama, mixtral, phi2
 from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
-from .utils import generate, load
-
-SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
-
+from .utils import generate, load, LORA_SUPPORTED_MODELS
 
 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@@ -166,10 +162,10 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
 
-    if model.__class__ not in SUPPORTED_MODELS:
+    if model.__class__ not in LORA_SUPPORTED_MODELS:
         raise ValueError(
             f"Model {model.__class__} not supported. "
-            f"Supported models: { SUPPORTED_MODELS}"
+            f"Supported models: {LORA_SUPPORTED_MODELS}"
         )
 
     # Freeze all layers other than LORA linears

mlx_lm/utils.py

@@ -25,6 +25,9 @@ MODEL_MAPPING = {
     "qwen": qwen,
     "plamo": plamo,
 }
+LORA_SUPPORTED_MODELS = [
+    llama.Model, mixtral.Model, phi2.Model, stablelm_epoch.Model
+]
 MAX_FILE_SIZE_GB = 5
 
 linear_class_predicate = (