mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
LoRA: Remove unnecessary model type judgments (#388)
* LoRA: Remove unnecessary model type judgments 1. Supported models are already checked in the load_model function in utils, no need to repeat the check in lora 2. The checks in lora are not synchronized with those in utils * LoRA: add LoRA supported models in mlx_lm utils
This commit is contained in:
parent
0a49ba0697
commit
ba3a9355d1
@ -7,13 +7,9 @@ import mlx.optimizers as optim
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
from .models import llama, mixtral, phi2
|
|
||||||
from .tuner.lora import LoRALinear
|
from .tuner.lora import LoRALinear
|
||||||
from .tuner.trainer import TrainingArgs, evaluate, train
|
from .tuner.trainer import TrainingArgs, evaluate, train
|
||||||
from .utils import generate, load
|
from .utils import generate, load, LORA_SUPPORTED_MODELS
|
||||||
|
|
||||||
SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
|
|
||||||
|
|
||||||
|
|
||||||
def build_parser():
|
def build_parser():
|
||||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||||
@ -166,10 +162,10 @@ if __name__ == "__main__":
|
|||||||
print("Loading pretrained model")
|
print("Loading pretrained model")
|
||||||
model, tokenizer = load(args.model)
|
model, tokenizer = load(args.model)
|
||||||
|
|
||||||
if model.__class__ not in SUPPORTED_MODELS:
|
if model.__class__ not in LORA_SUPPORTED_MODELS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model {model.__class__} not supported. "
|
f"Model {model.__class__} not supported. "
|
||||||
f"Supported models: { SUPPORTED_MODELS}"
|
f"Supported models: {LORA_SUPPORTED_MODELS}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Freeze all layers other than LORA linears
|
# Freeze all layers other than LORA linears
|
||||||
|
@ -25,6 +25,9 @@ MODEL_MAPPING = {
|
|||||||
"qwen": qwen,
|
"qwen": qwen,
|
||||||
"plamo": plamo,
|
"plamo": plamo,
|
||||||
}
|
}
|
||||||
|
LORA_SUPPORTED_MODELS = [
|
||||||
|
llama.Model, mixtral.Model, phi2.Model, stablelm_epoch.Model
|
||||||
|
]
|
||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
|
||||||
linear_class_predicate = (
|
linear_class_predicate = (
|
||||||
|
Loading…
Reference in New Issue
Block a user