feat: add yayi2-30b example

This commit is contained in:
Anchen
2024-01-01 01:49:12 +11:00
parent ee3c44d231
commit f4e38db123
4 changed files with 457 additions and 0 deletions

42
llms/yayi2/README.md Normal file
View File

@@ -0,0 +1,42 @@
# YAYI2
YAYI 2 is a collection of open-source large language models launched by Wenge Technology. YAYI2-30B is a Transformer-based large language model, and has been pretrained for 2.65 trillion tokens of multilingual data with high quality. The base model is aligned with human values through supervised fine-tuning with millions of instructions and reinforcement learning from human feedback (RLHF).
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
Next, download and convert the model.
```sh
python convert.py --hf-path <path_to_huggingface_model>
```
To generate a 4-bit quantized model, use `-q`. For a full list of options run:
```
python convert.py --help
```
The converter downloads the model from Hugging Face. The default model is
`wenge-research/yayi2-30b`. Check out the [Hugging Face
page](https://huggingface.co/wenge-research) to see a list of available models.
By default, the conversion script will save the converted `weights.npz`,
tokenizer, and `config.json` in the `mlx_model` directory.
### Run
Once you've converted the weights, you can interact with the Yayi2
model:
```
python yayi.py --prompt "The winter in Beijing is"
```

154
llms/yayi2/convert.py Normal file
View File

@@ -0,0 +1,154 @@
import argparse
import copy
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from yayi import Yayi, ModelArgs
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from transformers import AutoModelForCausalLM, AutoTokenizer
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model_args = ModelArgs(**config)
model = Yayi(model_args)
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def convert(args):
hf_path = Path(args.hf_path)
model = AutoModelForCausalLM.from_pretrained(
str(hf_path), trust_remote_code=True, torch_dtype=torch.float16
)
config = model.config.to_dict()
state_dict = model.state_dict()
tokenizer = AutoTokenizer.from_pretrained(
str(hf_path), trust_remote_code=True, use_fast=False
)
# things to change
# 1. there's no "model." in the weight names
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
# 2. mlp is called feed_forward
state_dict = {k.replace("mlp", "feed_forward"): v for k, v in state_dict.items()}
# 3. up_proj, down_proj, gate_proj
state_dict = {k.replace("down_proj", "w2"): v for k, v in state_dict.items()}
state_dict = {k.replace("up_proj", "w3"): v for k, v in state_dict.items()}
state_dict = {k.replace("gate_proj", "w1"): v for k, v in state_dict.items()}
# 4. layernorms
state_dict = {
k.replace("input_layernorm", "attention_norm"): v for k, v in state_dict.items()
}
state_dict = {
k.replace("post_attention_layernorm", "ffn_norm"): v
for k, v in state_dict.items()
}
# 5. lm head
state_dict = {k.replace("lm_head", "output"): v for k, v in state_dict.items()}
# 6. token emb
state_dict = {
k.replace("embed_tokens", "tok_embeddings"): v for k, v in state_dict.items()
}
# 7. attention
state_dict = {k.replace("self_attn", "attention"): v for k, v in state_dict.items()}
state_dict = {k.replace("q_proj", "wq"): v for k, v in state_dict.items()}
state_dict = {k.replace("k_proj", "wk"): v for k, v in state_dict.items()}
state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
weights = {k: v.numpy() for k, v in state_dict.items()}
keep_keys = set(
[
"vocab_size",
"hidden_size",
"num_attention_heads",
"num_hidden_layers",
"max_position_embeddings",
"rms_norm_eps",
"intermediate_size",
"rope_theta",
]
)
for k in list(config.keys()):
if k not in keep_keys:
config.pop(k)
return weights, config, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Yayi2 model to npz")
parser.add_argument(
"--hf-path",
help="The huggingface model to be converted",
default="wenge-research/yayi2-30b",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
weights, config, tokenizer = convert(args)
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "yayi"
json.dump(config, f, indent=4)

View File

@@ -0,0 +1,4 @@
torch
mlx
numpy
transformers>=4.35

257
llms/yayi2/yayi.py Normal file
View File

@@ -0,0 +1,257 @@
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
hidden_size: int = 7168
num_attention_heads: int = 64
num_hidden_layers: int = 64
max_position_embeddings: int = 4096
rms_norm_eps: float = 1e-6
intermediate_size: int = 16384
rope_theta: float = 100000
vocab_size: int = 81920
rope_traditional: bool = False
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads: int = args.num_attention_heads
self.num_key_value_heads: int = args.num_key_value_heads
self.hidden_size = args.hidden_size
self.head_dim = args.hidden_size // args.num_attention_heads
self.scale = self.head_dim**-0.5
self.wq = nn.Linear(
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
)
self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
self.wo = nn.Linear(
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
)
self.rope = nn.RoPE(
self.head_dim, traditional=args.rope_traditional, base=args.rope_theta
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, _ = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
0, 2, 1, 3
)
k = k.reshape(B, L, 1, self.head_dim).transpose(0, 2, 1, 3)
v = v.reshape(B, L, 1, self.head_dim).transpose(0, 2, 1, 3)
if cache is not None:
k_cache, v_cache = cache
q = self.rope(q, offset=k_cache.shape[2])
k = self.rope(k, offset=k_cache.shape[2])
k = mx.concatenate([k_cache, k], axis=2)
v = mx.concatenate([v_cache, v], axis=2)
else:
q = self.rope(q)
k = self.rope(k)
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
return self.wo(v_hat), (k, v)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Yayi(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, x, mask=None, cache=None):
x = self.tok_embeddings(x)
mask = None
T = x.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
x = self.norm(x)
return self.output(x), cache
def generate(
prompt: mx.array,
model: Yayi,
temp: float = 0.0,
):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
def load_model(model_path: str):
model_path = Path(model_path)
with open(model_path / "config.json", "r") as f:
config = json.load(f)
config.pop("model_type")
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
model = Yayi(model_args)
weights = mx.load(str(model_path / "weights.npz"))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Yayi inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the mlx model weights, tokenizer, and config",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="The winter in Beijing is",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.6,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model(args.model_path)
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
prompt = mx.array(prompt)
print(args.prompt, end="", flush=True)
tokens = []
skip = 0
for token, _ in zip(
generate(prompt, model, args.temp),
range(args.max_tokens),
):
if token == tokenizer.eos_token_id:
break
tokens.append(token.item())
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)