Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Added lora support for Phi-2 (#302)
* Added lora support for Phi-2

* Added Phi-2 support in fuse and convert

* format + readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
lora/lora.py · 11 changed lines (6 additions, 5 deletions)
```diff
@@ -9,9 +9,10 @@ from pathlib import Path
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import models
 import numpy as np
+import utils as lora_utils
 from mlx.utils import tree_flatten, tree_unflatten
+from models.lora import LoRALinear


 def build_parser():
@@ -270,7 +271,7 @@ def generate(model, prompt, tokenizer, args):
     tokens = []
     skip = 0
     for token, n in zip(
-        models.generate(prompt, model, args.temp),
+        lora_utils.generate(prompt, model, args.temp),
         range(args.max_tokens),
     ):
         if token == tokenizer.eos_token_id:
@@ -294,13 +295,13 @@ if __name__ == "__main__":
     np.random.seed(args.seed)

     print("Loading pretrained model")
-    model, tokenizer, _ = models.load(args.model)
+    model, tokenizer, _ = lora_utils.load(args.model)

     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
-        l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
-        l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
+        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
+        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)

     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
```
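On the same assumption, the `Total parameters` print in the last hunk counts frozen and trainable weights alike. A small follow-up using `Module.trainable_parameters()` from `mlx.nn` would report only the LoRA weights actually being trained (`model` here is the object returned by `lora_utils.load` above):

```python
from mlx.utils import tree_flatten

# Counts after model.freeze() and the LoRALinear swaps above: total includes
# the frozen base weights, trainable covers only the lora_a/lora_b matrices.
total_m = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
train_m = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Total parameters {total_m:.3f}M, trainable {train_m:.3f}M")
```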