mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
LoRA: Remove unnecessary model type judgments (#388)
* LoRA: Remove unnecessary model type judgments 1. Supported models are already checked in the load_model function in utils, no need to repeat the check in lora 2. The checks in lora are not synchronized with those in utils * LoRA: add LoRA supported models in mlx_lm utils
This commit is contained in:
parent
0a49ba0697
commit
ba3a9355d1
@ -7,13 +7,9 @@ import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .models import llama, mixtral, phi2
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.trainer import TrainingArgs, evaluate, train
|
||||
from .utils import generate, load
|
||||
|
||||
SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
|
||||
|
||||
from .utils import generate, load, LORA_SUPPORTED_MODELS
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
@ -166,10 +162,10 @@ if __name__ == "__main__":
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
if model.__class__ not in SUPPORTED_MODELS:
|
||||
if model.__class__ not in LORA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model {model.__class__} not supported. "
|
||||
f"Supported models: { SUPPORTED_MODELS}"
|
||||
f"Supported models: {LORA_SUPPORTED_MODELS}"
|
||||
)
|
||||
|
||||
# Freeze all layers other than LORA linears
|
||||
|
@ -25,6 +25,9 @@ MODEL_MAPPING = {
|
||||
"qwen": qwen,
|
||||
"plamo": plamo,
|
||||
}
|
||||
LORA_SUPPORTED_MODELS = [
|
||||
llama.Model, mixtral.Model, phi2.Model, stablelm_epoch.Model
|
||||
]
|
||||
MAX_FILE_SIZE_GB = 5
|
||||
|
||||
linear_class_predicate = (
|
||||
|
Loading…
Reference in New Issue
Block a user