Move lora example to use the same model format / conversion as hf_llm (#252)

* Hugging Face-ify the lora example to allow more models

* fixes

* comments

* more readme nits

* fusion + works better for qlora

* nits

* comments
Awni Hannun
2024-01-09 11:14:52 -08:00
committed by GitHub
parent bbd7172eef
commit 7b258f33ac
10 changed files with 521 additions and 224 deletions


@@ -1,23 +1,81 @@
# Copyright © 2023 Apple Inc.
import glob
import inspect
import json
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
import numpy as np
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
-dim: int
-n_layers: int
-head_dim: int
-hidden_dim: int
-n_heads: int
-n_kv_heads: int
-norm_eps: float
+hidden_size: int
+num_hidden_layers: int
+intermediate_size: int
+num_attention_heads: int
+rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
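# Illustrative sketch (not part of this diff): from_dict drops any config keys that the
# dataclass does not declare, so an unmodified Hugging Face config.json can be passed in
# directly. The values below are placeholders for a Llama-7B-style config.
#
#   args = ModelArgs.from_dict(
#       {
#           "hidden_size": 4096,
#           "num_hidden_layers": 32,
#           "intermediate_size": 11008,
#           "num_attention_heads": 32,
#           "num_key_value_heads": 32,
#           "rms_norm_eps": 1e-5,
#           "vocab_size": 32000,
#           "rope_scaling": {"type": "linear", "factor": 2.0},
#           "architectures": ["LlamaForCausalLM"],  # unknown key, silently ignored
#       }
#   )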
class Tokenizer:
def __init__(self, model_path: str):
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
self._eos = self._tokenizer.eos_token_id
self._bos = self._tokenizer.bos_token_id
def encode(self, s: str, eos: bool = False) -> mx.array:
toks = self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
if eos:
toks = np.concatenate([toks, [self._eos]])
return mx.array(toks)
@property
def eos_id(self) -> int:
return self._eos
def decode(self, t: List[int]) -> str:
return self._tokenizer.decode(t)
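# Usage sketch (illustrative, not part of this diff); "path/to/mlx_model" is a placeholder:
#
#   tokenizer = Tokenizer("path/to/mlx_model")
#   prompt = tokenizer.encode("hello", eos=False)  # 1-D mx.array of token ids
#   text = tokenizer.decode(prompt.tolist())       # back to a string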
class LoRALinear(nn.Module):
@@ -32,14 +90,58 @@ class LoRALinear(nn.Module):
lora_lin.linear = linear
return lora_lin
def to_linear(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, nn.QuantizedLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
-self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
+self,
+input_dims: int,
+output_dims: int,
+lora_rank: int = 8,
+bias: bool = False,
+scale: float = 20.0,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
@@ -55,7 +157,7 @@ class LoRALinear(nn.Module):
dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b
-return y + 2.0 * z
+return y + self.scale * z
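# Editorial note (not part of this diff): the forward pass computes
# y = W x + scale * (x A) B, and to_linear folds the same low-rank update into the
# base weight, W_fused = W + scale * (A B)^T. A quantized base layer is dequantized
# to float16, fused, then re-quantized with its original group_size and bits.
# Minimal fusion sketch for a standalone layer:
#
#   lora = LoRALinear(4096, 4096)
#   fused = lora.to_linear()  # nn.Linear with the adapter folded into its weight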
class RMSNorm(nn.Module):
@@ -75,20 +177,31 @@ class RMSNorm(nn.Module):
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
-self.n_heads: int = args.n_heads
-self.n_kv_heads: int = args.n_kv_heads
+dim = args.hidden_size
+self.n_heads = n_heads = args.num_attention_heads
+self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-self.repeats = self.n_heads // self.n_kv_heads
+self.repeats = n_heads // n_kv_heads
-self.scale = self.args.head_dim**-0.5
+head_dim = args.hidden_size // n_heads
+self.scale = head_dim**-0.5
-self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
-self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-self.rope = nn.RoPE(args.head_dim, traditional=True)
+self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+rope_scale = (
+1 / args.rope_scaling["factor"]
+if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
+else 1
+)
+self.rope = nn.RoPE(
+head_dim,
+traditional=args.rope_traditional,
+base=args.rope_theta,
+scale=rope_scale,
+)
def __call__(
self,
@@ -98,7 +211,7 @@ class Attention(nn.Module):
) -> mx.array:
B, L, D = x.shape
-queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
@@ -127,30 +240,29 @@ class Attention(nn.Module):
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-return self.wo(output), (keys, values)
+return self.o_proj(output), (keys, values)
-class FeedForward(nn.Module):
-def __init__(self, args: ModelArgs):
+class MLP(nn.Module):
+def __init__(self, dim, hidden_dim):
super().__init__()
-self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
-self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
-self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
-return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
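# Editorial note (not part of this diff): this is the gated-SiLU ("SwiGLU") feed-forward
# used by Llama-family models; the projection names match the Hugging Face checkpoint keys.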
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
-self.n_heads = args.n_heads
-self.dim = args.dim
-self.attention = Attention(args)
-self.feed_forward = FeedForward(args=args)
-self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+self.num_attention_heads = args.num_attention_heads
+self.hidden_size = args.hidden_size
+self.self_attn = Attention(args)
+self.mlp = MLP(args.hidden_size, args.intermediate_size)
+self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
@@ -159,31 +271,32 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
-r, cache = self.attention(self.attention_norm(x), mask, cache)
+r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
-r = self.feed_forward(self.ffn_norm(h))
+r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
-class Model(nn.Module):
+class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
-self.n_layers = args.n_layers
+self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
-self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
-self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-self.norm = RMSNorm(args.dim, eps=args.norm_eps)
-self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+self.layers = [
+TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+]
+self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
-h = self.tok_embeddings(inputs)
+h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
@@ -196,4 +309,70 @@ class Model(nn.Module):
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
-return self.output(self.norm(h)), cache
+return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
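# Editorial note (not part of this diff): wrapping LlamaModel in Model mirrors the Hugging
# Face module layout, so checkpoint keys such as "model.embed_tokens.weight",
# "model.layers.0.self_attn.q_proj.weight", and "lm_head.weight" load via load_weights
# without any renaming.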
def load(path_or_hf_repo: str):
# If the path exists, try to load the model from it;
# otherwise download it from the Hugging Face hub and cache it locally
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
return model, Tokenizer(model_path), config
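# Usage sketch (illustrative, not part of this diff); the argument is either a local
# directory of converted weights or a Hugging Face repo id ("path/to/mlx_model" is a
# placeholder). If config.json has a "quantization" entry, the Linear layers are swapped
# for nn.QuantizedLinear before the weights are loaded, so quantized checkpoints load as-is.
#
#   model, tokenizer, config = load("path/to/mlx_model")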
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
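# End-to-end sketch (illustrative, not part of this diff): generate is an infinite,
# cache-carrying generator, so the caller decides when to stop. The model path is a
# placeholder and the 100-token cap is arbitrary.
#
#   model, tokenizer, _ = load("path/to/mlx_model")
#   prompt = tokenizer.encode("Write a haiku about the ocean.")
#   tokens = []
#   for token, _ in zip(generate(prompt, model), range(100)):
#       if token.item() == tokenizer.eos_id:
#           break
#       tokens.append(token.item())
#   print(tokenizer.decode(tokens))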