Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-10-23 22:18:06 +08:00
Added lora support for Phi-2 (#302)
* Added lora support for Phi-2
* Added Phi-2 support in fuse and convert
* format + readme

Co-authored-by: Awni Hannun <awni@apple.com>
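After this change, adapters trained on Phi-2 can be folded back into the base weights with the same fuse script, e.g. `python fuse.py --model <phi-2 path or hf repo> --adapter-file adapters.npz`; the flag names here are assumed from the `args.model` and `args.adapter_file` attributes visible in the diff below, not quoted from the updated README.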
lora/fuse.py (10 lines changed)
lora/fuse.py
@@ -4,9 +4,9 @@ import argparse
 from pathlib import Path
 
 import mlx.core as mx
-import models
 import utils
 from mlx.utils import tree_flatten, tree_unflatten
+from models.lora import LoRALinear
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@@ -45,7 +45,7 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     args = parser.parse_args()
 
-    model, tokenizer, config = models.load(args.model)
+    model, tokenizer, config = utils.load(args.model)
 
     # Load adapters and get number of LoRA layers
     adapters = list(mx.load(args.adapter_file).items())
@@ -54,14 +54,14 @@ if __name__ == "__main__":
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[-lora_layers:]:
-        l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
-        l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
+        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
+        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
 
     model.update(tree_unflatten(adapters))
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, models.LoRALinear)
+        if isinstance(m, LoRALinear)
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
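For readers following the code: `LoRALinear.from_linear` wraps an existing projection so the low-rank adapter runs alongside the frozen weight, and `to_linear()` folds the trained adapter back into a single dense layer before saving. A minimal sketch of that fusion step follows; the adapter shapes, the scale value, and the helper name `fuse_lora` are assumptions for illustration, not the repo's implementation.

# Conceptual sketch only: what a LoRA fuse step like LoRALinear.to_linear()
# is expected to compute.
import mlx.core as mx
import mlx.nn as nn


def fuse_lora(linear: nn.Linear, lora_a: mx.array, lora_b: mx.array, scale: float = 2.0) -> nn.Linear:
    # Assumed layout: lora_a is (input_dims, rank), lora_b is (rank, output_dims).
    output_dims, input_dims = linear.weight.shape
    fused = nn.Linear(input_dims, output_dims, bias="bias" in linear)
    # y = x @ W.T + scale * (x @ A) @ B  is the same as  x @ (W + scale * (A @ B).T).T,
    # so the low-rank update is absorbed into the dense weight.
    fused.weight = linear.weight + scale * (lora_a @ lora_b).T
    if "bias" in linear:
        fused.bias = linear.bias
    return fused

Under those assumptions, fusing a single projection would look like `layer.self_attn.q_proj = fuse_lora(layer.self_attn.q_proj, a, b)`, which is the per-layer analogue of the `to_linear()` comprehension in the last hunk above.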