mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
feat: add yayi2-30b example
This commit is contained in:
42
llms/yayi2/README.md
Normal file
42
llms/yayi2/README.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# YAYI2
|
||||
|
||||
YAYI 2 is a collection of open-source large language models launched by Wenge Technology. YAYI2-30B is a Transformer-based large language model, and has been pretrained for 2.65 trillion tokens of multilingual data with high quality. The base model is aligned with human values through supervised fine-tuning with millions of instructions and reinforcement learning from human feedback (RLHF).
|
||||
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download and convert the model.
|
||||
|
||||
```sh
|
||||
python convert.py --hf-path <path_to_huggingface_model>
|
||||
```
|
||||
|
||||
To generate a 4-bit quantized model, use `-q`. For a full list of options run:
|
||||
|
||||
```
|
||||
python convert.py --help
|
||||
```
|
||||
|
||||
The converter downloads the model from Hugging Face. The default model is
|
||||
`wenge-research/yayi2-30b`. Check out the [Hugging Face
|
||||
page](https://huggingface.co/wenge-research) to see a list of available models.
|
||||
|
||||
By default, the conversion script will save the converted `weights.npz`,
|
||||
tokenizer, and `config.json` in the `mlx_model` directory.
|
||||
|
||||
### Run
|
||||
|
||||
Once you've converted the weights, you can interact with the Yayi2
|
||||
model:
|
||||
|
||||
```
|
||||
python yayi.py --prompt "The winter in Beijing is"
|
||||
```
|
||||
|
||||
|
154
llms/yayi2/convert.py
Normal file
154
llms/yayi2/convert.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
from yayi import Yayi, ModelArgs
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
# Load the model:
|
||||
model_args = ModelArgs(**config)
|
||||
model = Yayi(model_args)
|
||||
|
||||
weights = tree_map(mx.array, weights)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
# Quantize the model:
|
||||
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {
|
||||
"group_size": args.q_group_size,
|
||||
"bits": args.q_bits,
|
||||
}
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def convert(args):
|
||||
hf_path = Path(args.hf_path)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
str(hf_path), trust_remote_code=True, torch_dtype=torch.float16
|
||||
)
|
||||
config = model.config.to_dict()
|
||||
|
||||
state_dict = model.state_dict()
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
str(hf_path), trust_remote_code=True, use_fast=False
|
||||
)
|
||||
|
||||
# things to change
|
||||
# 1. there's no "model." in the weight names
|
||||
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
# 2. mlp is called feed_forward
|
||||
state_dict = {k.replace("mlp", "feed_forward"): v for k, v in state_dict.items()}
|
||||
|
||||
# 3. up_proj, down_proj, gate_proj
|
||||
state_dict = {k.replace("down_proj", "w2"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("up_proj", "w3"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("gate_proj", "w1"): v for k, v in state_dict.items()}
|
||||
|
||||
# 4. layernorms
|
||||
state_dict = {
|
||||
k.replace("input_layernorm", "attention_norm"): v for k, v in state_dict.items()
|
||||
}
|
||||
state_dict = {
|
||||
k.replace("post_attention_layernorm", "ffn_norm"): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
# 5. lm head
|
||||
state_dict = {k.replace("lm_head", "output"): v for k, v in state_dict.items()}
|
||||
|
||||
# 6. token emb
|
||||
state_dict = {
|
||||
k.replace("embed_tokens", "tok_embeddings"): v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
# 7. attention
|
||||
state_dict = {k.replace("self_attn", "attention"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("q_proj", "wq"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("k_proj", "wk"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
|
||||
|
||||
weights = {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
keep_keys = set(
|
||||
[
|
||||
"vocab_size",
|
||||
"hidden_size",
|
||||
"num_attention_heads",
|
||||
"num_hidden_layers",
|
||||
"max_position_embeddings",
|
||||
"rms_norm_eps",
|
||||
"intermediate_size",
|
||||
"rope_theta",
|
||||
]
|
||||
)
|
||||
for k in list(config.keys()):
|
||||
if k not in keep_keys:
|
||||
config.pop(k)
|
||||
|
||||
return weights, config, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Yayi2 model to npz")
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
help="The huggingface model to be converted",
|
||||
default="wenge-research/yayi2-30b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
weights, config, tokenizer = convert(args)
|
||||
|
||||
if args.quantize:
|
||||
print("[INFO] Quantizing")
|
||||
weights, config = quantize(weights, config, args)
|
||||
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
with open(mlx_path / "config.json", "w") as f:
|
||||
config["model_type"] = "yayi"
|
||||
json.dump(config, f, indent=4)
|
4
llms/yayi2/requirements.txt
Normal file
4
llms/yayi2/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch
|
||||
mlx
|
||||
numpy
|
||||
transformers>=4.35
|
257
llms/yayi2/yayi.py
Normal file
257
llms/yayi2/yayi.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
hidden_size: int = 7168
|
||||
num_attention_heads: int = 64
|
||||
num_hidden_layers: int = 64
|
||||
max_position_embeddings: int = 4096
|
||||
rms_norm_eps: float = 1e-6
|
||||
intermediate_size: int = 16384
|
||||
rope_theta: float = 100000
|
||||
vocab_size: int = 81920
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads: int = args.num_attention_heads
|
||||
self.num_key_value_heads: int = args.num_key_value_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.head_dim = args.hidden_size // args.num_attention_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(
|
||||
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
|
||||
self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
|
||||
self.wo = nn.Linear(
|
||||
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim, traditional=args.rope_traditional, base=args.rope_theta
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
q, k, v = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
k = k.reshape(B, L, 1, self.head_dim).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(B, L, 1, self.head_dim).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
q = self.rope(q, offset=k_cache.shape[2])
|
||||
k = self.rope(k, offset=k_cache.shape[2])
|
||||
k = mx.concatenate([k_cache, k], axis=2)
|
||||
v = mx.concatenate([v_cache, v], axis=2)
|
||||
|
||||
else:
|
||||
q = self.rope(q)
|
||||
k = self.rope(k)
|
||||
|
||||
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
|
||||
|
||||
return self.wo(v_hat), (k, v)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
|
||||
self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
|
||||
self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Yayi(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
x = self.tok_embeddings(x)
|
||||
mask = None
|
||||
T = x.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
x = self.norm(x)
|
||||
return self.output(x), cache
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: mx.array,
|
||||
model: Yayi,
|
||||
temp: float = 0.0,
|
||||
):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
|
||||
|
||||
def load_model(model_path: str):
|
||||
model_path = Path(model_path)
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
config.pop("model_type")
|
||||
quantization = config.pop("quantization", None)
|
||||
model_args = ModelArgs(**config)
|
||||
|
||||
model = Yayi(model_args)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Yayi inference script")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the mlx model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="The winter in Beijing is",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=0.6,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
prompt = tokenizer(
|
||||
args.prompt,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)[
|
||||
"input_ids"
|
||||
][0]
|
||||
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
print(args.prompt, end="", flush=True)
|
||||
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, _ in zip(
|
||||
generate(prompt, model, args.temp),
|
||||
range(args.max_tokens),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
tokens.append(token.item())
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
Reference in New Issue
Block a user