Lazy import + refactor Lora layer addition (#426)

* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
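
The "lazy model import in mlx_lm" bullet refers to resolving the model architecture module from the checkpoint's model_type at load time, instead of importing every supported architecture up front. A minimal sketch of the idea, assuming a models package with one module per model_type that exposes Model and ModelArgs classes (the helper name, module path, and error handling here are illustrative, not necessarily what mlx_lm actually uses):

import importlib


def get_model_classes(config: dict):
    # Hypothetical helper: look up the architecture module named after
    # model_type, so unrelated architectures are never imported.
    model_type = config["model_type"]
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        # Illustrative error handling; the real loader may differ.
        raise ValueError(f"Model type {model_type} not supported.")
    return arch.Model, arch.ModelArgs

With a loader like this, the LORA_SUPPORTED_MODELS allowlist removed in the diff below is no longer needed: any model_type with a corresponding module can be loaded.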
Author: Awni Hannun
Date: 2024-02-12 10:51:02 -08:00
Committed by: GitHub
Parent: 4576946151
Commit: d4666615bb
15 changed files with 127 additions and 393 deletions

@@ -9,7 +9,8 @@ from mlx.utils import tree_flatten
 
 from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
-from .utils import LORA_SUPPORTED_MODELS, generate, load
+from .tuner.utils import linear_to_lora_layers
+from .utils import generate, load
 
 
 def build_parser():
@@ -169,19 +170,10 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
 
-    if model.__class__ not in LORA_SUPPORTED_MODELS:
-        raise ValueError(
-            f"Model {model.__class__} not supported. "
-            f"Supported models: {LORA_SUPPORTED_MODELS}"
-        )
-
-    # Freeze all layers other than LORA linears
+    # Freeze all layers
     model.freeze()
-    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
-        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
-        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
-        if hasattr(l, "block_sparse_moe"):
-            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
+    # Convert linear layers to lora layers and unfreeze in the process
+    linear_to_lora_layers(model, args.lora_layers)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")