Added lora support for Phi-2 (#302)

* Added lora support for Phi-2

* Added Phi-2 support in fuse and convert

* format + readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Yousif 2024-01-12 13:45:30 -08:00 committed by GitHub
parent 3ac731dd4f
commit 7575125d5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 564 additions and 25 deletions

View File

@ -2,7 +2,7 @@
This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style
(QLoRA).[^qlora] The example works with Llama, Mistral, and Phi-2 style
models available on Hugging Face.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
@ -81,7 +81,7 @@ To fine-tune a model use:
```
python lora.py --model <path_to_model> \
--train \
--iters 600
--iters 600 \
```
If `--model` points to a quantized model, then the training will use QLoRA,
@ -100,7 +100,7 @@ To compute test set perplexity use:
```
python lora.py --model <path_to_model> \
--adapter-file <path_to_adapters.npz> \
--test
--test \
```
### Generate
@ -114,7 +114,7 @@ python lora.py --model <path_to_model> \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
A: " \
```
## Results
@ -211,7 +211,7 @@ python lora.py \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4
--lora-layers 4 \
```
The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.

View File

@ -7,14 +7,16 @@ import mlx.core as mx
import mlx.nn as nn
import utils
from mlx.utils import tree_flatten
from models import Model, ModelArgs
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Get model classes
model_class, model_args_class = utils._get_classes(config=config)
# Load the model:
model = Model(ModelArgs.from_dict(config))
model = model_class(model_args_class.from_dict(config))
model.load_weights(list(weights.items()))
# Quantize the model:

View File

@ -4,9 +4,9 @@ import argparse
from pathlib import Path
import mlx.core as mx
import models
import utils
from mlx.utils import tree_flatten, tree_unflatten
from models.lora import LoRALinear
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@ -45,7 +45,7 @@ if __name__ == "__main__":
print("Loading pretrained model")
args = parser.parse_args()
model, tokenizer, config = models.load(args.model)
model, tokenizer, config = utils.load(args.model)
# Load adapters and get number of LoRA layers
adapters = list(mx.load(args.adapter_file).items())
@ -54,14 +54,14 @@ if __name__ == "__main__":
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[-lora_layers:]:
l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
model.update(tree_unflatten(adapters))
fused_linears = [
(n, m.to_linear())
for n, m in model.named_modules()
if isinstance(m, models.LoRALinear)
if isinstance(m, LoRALinear)
]
model.update_modules(tree_unflatten(fused_linears))

View File

@ -9,9 +9,10 @@ from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import models
import numpy as np
import utils as lora_utils
from mlx.utils import tree_flatten, tree_unflatten
from models.lora import LoRALinear
def build_parser():
@ -270,7 +271,7 @@ def generate(model, prompt, tokenizer, args):
tokens = []
skip = 0
for token, n in zip(
models.generate(prompt, model, args.temp),
lora_utils.generate(prompt, model, args.temp),
range(args.max_tokens),
):
if token == tokenizer.eos_token_id:
@ -294,13 +295,13 @@ if __name__ == "__main__":
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer, _ = models.load(args.model)
model, tokenizer, _ = lora_utils.load(args.model)
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")

0
lora/models/__init__.py Normal file
View File

15
lora/models/base.py Normal file
View File

@ -0,0 +1,15 @@
import inspect
from dataclasses import dataclass
@dataclass
class BaseModelArgs:
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)

202
lora/models/llama.py Normal file
View File

@ -0,0 +1,202 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache

86
lora/models/lora.py Normal file
View File

@ -0,0 +1,86 @@
import math
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
def to_linear(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, nn.QuantizedLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
lora_rank: int = 8,
bias: bool = False,
scale: float = 20.0,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, lora_rank),
)
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
dtype = self.linear.weight.dtype
if isinstance(self.linear, nn.QuantizedLinear):
dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b
return y + self.scale * z

138
lora/models/phi2.py Normal file
View File

@ -0,0 +1,138 @@
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
n_positions: int = 2048
vocab_size: int = 51200
n_embd: int = 2560
n_head: int = 32
n_layer: int = 32
rotary_dim: int = 32
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, n_head: int, rotary_dim: int):
super().__init__()
self.n_head = n_head
self.q_proj = nn.Linear(dims, dims)
self.k_proj = nn.Linear(dims, dims)
self.v_proj = nn.Linear(dims, dims)
self.dense = nn.Linear(dims, dims)
self.rope = nn.RoPE(rotary_dim, traditional=False)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
n_head = self.n_head
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(values_hat), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.n_embd
mlp_dims = dims * 4
self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
self.input_layernorm = LayerNorm(dims)
self.mlp = MLP(dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
attn_h, cache = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x, cache
class Transformer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
self.final_layernorm = LayerNorm(config.n_embd)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
return self.final_layernorm(x), cache
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model = Transformer(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.model(x, mask, cache)
return self.lm_head(y), cache

View File

@ -2,12 +2,44 @@
import glob
import json
import logging
from pathlib import Path
from typing import Generator
import mlx.core as mx
import mlx.nn as nn
import models.llama as llama
import models.phi2 as phi2
import transformers
from huggingface_hub import snapshot_download
# Constants
MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"phi": phi2,
}
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
@ -88,3 +120,71 @@ def save_model(save_dir: str, weights, tokenizer, config):
tokenizer.save_pretrained(save_dir)
with open(save_dir / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
def load(path_or_hf_repo: str):
# If the path exists, it will try to load model form it
# otherwise download and cache from the hf_repo and cache
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model_class, model_args_class = _get_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
return model, tokenizer, config
def generate(
prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
"""
Generate text based on the given prompt and model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling. If temp is 0, use max sampling.
Yields:
mx.array: The generated text.
"""
def sample(logits: mx.array) -> mx.array:
return (
mx.argmax(logits, axis=-1)
if temp == 0
else mx.random.categorical(logits * (1 / temp))
)
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y

View File

@ -3,9 +3,9 @@
import argparse
import mlx.core as mx
import numpy as np
from PIL import Image
from tqdm import tqdm
import numpy as np
from stable_diffusion import StableDiffusion

View File

@ -5,9 +5,8 @@ from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from . import whisper
@ -18,11 +17,7 @@ def load_model(
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo
)
)
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())