LoRA: Remove unnecessary model type judgments (#388)

* LoRA: Remove unnecessary model type judgments

1. Supported models are already checked in the load_model function in utils, no need to repeat the check in lora
2. The checks in lora are not synchronized with those in utils

* LoRA: add LoRA supported models in mlx_lm utils
This commit is contained in:
Madroid Ma 2024-02-01 03:55:27 +08:00 committed by GitHub
parent 0a49ba0697
commit ba3a9355d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 7 deletions

View File

@@ -7,13 +7,9 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten

-from .models import llama, mixtral, phi2
 from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
-from .utils import generate, load
+from .utils import generate, load, LORA_SUPPORTED_MODELS

-SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
-

 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@@ -166,10 +162,10 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     model, tokenizer = load(args.model)

-    if model.__class__ not in SUPPORTED_MODELS:
+    if model.__class__ not in LORA_SUPPORTED_MODELS:
         raise ValueError(
             f"Model {model.__class__} not supported. "
-            f"Supported models: {SUPPORTED_MODELS}"
+            f"Supported models: {LORA_SUPPORTED_MODELS}"
         )

     # Freeze all layers other than LORA linears

View File

@@ -25,6 +25,9 @@ MODEL_MAPPING = {
     "qwen": qwen,
     "plamo": plamo,
 }

+LORA_SUPPORTED_MODELS = [
+    llama.Model, mixtral.Model, phi2.Model, stablelm_epoch.Model
+]

 MAX_FILE_SIZE_GB = 5

 linear_class_predicate = (