Added LoRA support for Phi-2 (#302)

* Added LoRA support for Phi-2

* Added Phi-2 support in fuse and convert

* Formatting and README updates

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Yousif
Date: 2024-01-12 13:45:30 -08:00
Committed by: GitHub
Parent: 3ac731dd4f
Commit: 7575125d5d
12 changed files with 564 additions and 25 deletions
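
For orientation before the diff: the training-side pattern this commit extends to Phi-2 is the usual LoRA recipe of freezing the base model and wrapping the attention projections. The sketch below is illustrative only; utils.load and LoRALinear are taken from the diff that follows, while the model path and layer count are assumptions.

# Rough sketch of applying LoRA to Phi-2 with the pieces this commit touches.
# The model path and the number of adapted layers are placeholders.
import utils
from mlx.utils import tree_flatten
from models.lora import LoRALinear

model, tokenizer, config = utils.load("microsoft/phi-2")  # path is an assumption

# Freeze the base weights; only the low-rank adapters stay trainable.
model.freeze()
for layer in model.model.layers[-16:]:  # e.g. adapt the last 16 blocks
    layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj)
    layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj)

n_trainable = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"trainable parameters: {n_trainable}")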


@@ -4,9 +4,9 @@ import argparse
 from pathlib import Path
 
 import mlx.core as mx
-import models
 import utils
 from mlx.utils import tree_flatten, tree_unflatten
+from models.lora import LoRALinear
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@@ -45,7 +45,7 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     args = parser.parse_args()
 
-    model, tokenizer, config = models.load(args.model)
+    model, tokenizer, config = utils.load(args.model)
 
     # Load adapters and get number of LoRA layers
     adapters = list(mx.load(args.adapter_file).items())
@@ -54,14 +54,14 @@ if __name__ == "__main__":
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[-lora_layers:]:
-        l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
-        l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
+        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
+        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
 
     model.update(tree_unflatten(adapters))
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, models.LoRALinear)
+        if isinstance(m, LoRALinear)
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
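
The fuse step above depends on two LoRALinear methods: from_linear, which wraps an existing projection while reusing its frozen weights, and to_linear, which folds the trained low-rank factors back into a single plain linear layer. A minimal sketch of such a wrapper follows; it illustrates the technique and is not the repo's exact implementation (the real class also covers rank, scaling, bias, and quantization options).

# Minimal LoRALinear sketch: wraps an nn.Linear, adds low-rank factors A and B,
# and can fuse them back into a single weight matrix. Illustrative only.
import math
import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    def __init__(self, input_dims: int, output_dims: int, rank: int = 8):
        super().__init__()
        self.linear = nn.Linear(input_dims, output_dims, bias=False)
        self.scale = 2.0
        # A starts small and random, B starts at zero, so the wrapper is a no-op
        # until training updates the adapters.
        bound = 1.0 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(input_dims, rank))
        self.lora_b = mx.zeros((rank, output_dims))

    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8):
        output_dims, input_dims = linear.weight.shape
        lora_lin = LoRALinear(input_dims, output_dims, rank)
        lora_lin.linear = linear  # reuse the frozen base weights
        return lora_lin

    def to_linear(self) -> nn.Linear:
        # Fuse: y = x W^T + scale * (x A) B  ==  x (W + scale * B^T A^T)^T
        output_dims, input_dims = self.linear.weight.shape
        fused = nn.Linear(input_dims, output_dims, bias=False)
        fused.weight = self.linear.weight + self.scale * self.lora_b.T @ self.lora_a.T
        return fused

    def __call__(self, x):
        return self.linear(x) + self.scale * (x @ self.lora_a) @ self.lora_b

After fusing, the resulting weights can be saved as an ordinary checkpoint, so inference no longer needs the adapter file.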