mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
some updates / style consistency
This commit is contained in:
parent
2a9c5e8a8c
commit
a476ed9f50
@ -1,25 +1,37 @@
|
|||||||
# Qwen
|
# Qwen
|
||||||
|
|
||||||
Qwen (通义千问) is a language model proposed by Alibaba Cloud[^1]. The architecture of Qwen is similar to Llama except for the bias in the attention layers.
|
Qwen (通义千问) is a language model developed by Alibaba Cloud.[^1] The
|
||||||
|
architecture of Qwen is similar to Llama except for the bias in the attention
|
||||||
|
layers.
|
||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
|
|
||||||
Download (from huggingface) and conver the model. By default, the model is `Qwen/Qwen-1_8B`.
|
First download and convert the model with:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
python convert.py
|
python convert.py
|
||||||
```
|
```
|
||||||
|
The script downloads the model from Hugging Face. The default model is
|
||||||
|
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
|
||||||
|
|
||||||
This will make the `weights.npz` file which MLX can read.
|
The conversion script will make the `weights.npz` and `params.json` files in
|
||||||
|
the working directory.
|
||||||
|
|
||||||
## Generate
|
## Generate
|
||||||
|
|
||||||
To generate text with the default prompt (default tokenizer is `Qwen/Qwen-1_8B`):
|
To generate text with the default prompt:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
python qwen.py
|
python qwen.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
|
||||||
|
for Qwen 7B use:
|
||||||
|
|
||||||
|
```
|
||||||
|
python qwen.py --tokenizer Qwen/Qwen-7B
|
||||||
|
```
|
||||||
|
|
||||||
To see a list of options, run:
|
To see a list of options, run:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
|
@ -25,7 +25,7 @@ def convert(model_path: str = "Qwen/Qwen-1_8B"):
|
|||||||
config = model.config
|
config = model.config
|
||||||
config_dict = config.to_dict()
|
config_dict = config.to_dict()
|
||||||
with open("config.json", "w") as f:
|
with open("config.json", "w") as f:
|
||||||
json.dump(config_dict, f)
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
81
qwen/qwen.py
81
qwen/qwen.py
@ -1,13 +1,9 @@
|
|||||||
# The architecture of qwen is similar to Llama.
|
|
||||||
# This inference script is mainly for compatibility with the huggingface model of qwen.
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
from dataclasses import dataclass
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -17,37 +13,44 @@ class ModelArgs:
|
|||||||
num_attention_heads: int = 16
|
num_attention_heads: int = 16
|
||||||
num_hidden_layers: int = 24
|
num_hidden_layers: int = 24
|
||||||
kv_channels: int = 128
|
kv_channels: int = 128
|
||||||
|
|
||||||
max_position_embeddings: int = 8192
|
max_position_embeddings: int = 8192
|
||||||
layer_norm_epsilon: float = 1e-6
|
layer_norm_epsilon: float = 1e-6
|
||||||
intermediate_size: int = 11008
|
intermediate_size: int = 11008
|
||||||
|
|
||||||
no_bias: bool = True
|
no_bias: bool = True
|
||||||
|
|
||||||
vocab_size: int = 151936
|
vocab_size: int = 151936
|
||||||
|
|
||||||
|
|
||||||
class QWenAttntion(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = mx.ones((dims,))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||||
|
return self.weight * output
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.hidden_size = args.hidden_size
|
hidden_size = args.hidden_size
|
||||||
self.num_attention_heads = args.num_attention_heads
|
self.num_attention_heads = args.num_attention_heads
|
||||||
|
|
||||||
self.hidden_size_per_attention_head = (
|
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
|
||||||
self.hidden_size // self.num_attention_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
self.rotary_emb = nn.RoPE(
|
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
|
||||||
self.hidden_size_per_attention_head, traditional=False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.proj_size = args.kv_channels * self.num_attention_heads
|
proj_size = args.kv_channels * self.num_attention_heads
|
||||||
|
|
||||||
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
|
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
|
||||||
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
|
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
|
||||||
|
|
||||||
self.scale = self.hidden_size_per_attention_head**-0.5
|
self.scale = hidden_size_per_attention_head**-0.5
|
||||||
|
|
||||||
def __call__(self, x, mask=None, cache=None):
|
def __call__(self, x, mask=None, cache=None):
|
||||||
qkv = self.c_attn(x)
|
qkv = self.c_attn(x)
|
||||||
@ -76,13 +79,13 @@ class QWenAttntion(nn.Module):
|
|||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask
|
scores = scores + mask
|
||||||
|
|
||||||
scores = mx.softmax(scores, axis=-1).astype(v.dtype)
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||||
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
return self.c_proj(v_hat), (k, v)
|
return self.c_proj(v_hat), (k, v)
|
||||||
|
|
||||||
|
|
||||||
class QWenMlp(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -99,19 +102,17 @@ class QWenMlp(nn.Module):
|
|||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
a1 = self.w1(x)
|
a1 = self.w1(x)
|
||||||
a2 = self.w2(x)
|
a2 = self.w2(x)
|
||||||
intermediate_parallel = a1 * nn.silu(a2)
|
return self.c_proj(a1 * nn.silu(a2))
|
||||||
out = self.c_proj(intermediate_parallel)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class QWenBlock(nn.Module):
|
class TransformerBlock(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
self.attn = QWenAttntion(args)
|
self.attn = Attention(args)
|
||||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
self.mlp = QWenMlp(args)
|
self.mlp = MLP(args)
|
||||||
|
|
||||||
def __call__(self, x, mask=None, cache=None):
|
def __call__(self, x, mask=None, cache=None):
|
||||||
residual = x
|
residual = x
|
||||||
@ -125,15 +126,15 @@ class QWenBlock(nn.Module):
|
|||||||
return x, cache
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
class QWen(nn.Module):
|
class Qwen(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_dim = args.hidden_size
|
self.embed_dim = args.hidden_size
|
||||||
|
|
||||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)]
|
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||||
self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
|
self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
|
||||||
|
|
||||||
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
|
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
|
||||||
|
|
||||||
@ -141,8 +142,9 @@ class QWen(nn.Module):
|
|||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if x.shape[1] > 1:
|
T = x.shape[1]
|
||||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
if T > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -151,12 +153,11 @@ class QWen(nn.Module):
|
|||||||
for e, layer in enumerate(self.h):
|
for e, layer in enumerate(self.h):
|
||||||
x, cache[e] = layer(x, mask, cache[e])
|
x, cache[e] = layer(x, mask, cache[e])
|
||||||
|
|
||||||
x = self.ln_f(x)
|
x = self.ln_f(x[:, T - 1 : T, :])
|
||||||
|
|
||||||
return self.lm_head(x), cache
|
return self.lm_head(x), cache
|
||||||
|
|
||||||
|
|
||||||
def generate(prompt: mx.array, model: QWen, temp: 0.0):
|
def generate(prompt: mx.array, model: Qwen, temp: 0.0):
|
||||||
def sample(logits):
|
def sample(logits):
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
return mx.argmax(logits, axis=-1)
|
return mx.argmax(logits, axis=-1)
|
||||||
@ -190,7 +191,7 @@ def load_model(
|
|||||||
model_args.intermediate_size = config["intermediate_size"]
|
model_args.intermediate_size = config["intermediate_size"]
|
||||||
model_args.no_bias = config["no_bias"]
|
model_args.no_bias = config["no_bias"]
|
||||||
|
|
||||||
model = QWen(model_args)
|
model = Qwen(model_args)
|
||||||
|
|
||||||
weights = mx.load("weights.npz")
|
weights = mx.load("weights.npz")
|
||||||
model.update(tree_unflatten(list(weights.items())))
|
model.update(tree_unflatten(list(weights.items())))
|
||||||
@ -201,8 +202,6 @@ def load_model(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# The infernece code and arguments were mainly derived from phi-2 example.
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Qwen inference script")
|
parser = argparse.ArgumentParser(description="Qwen inference script")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer",
|
"--tokenizer",
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
|
einops
|
||||||
mlx
|
mlx
|
||||||
numpy
|
numpy
|
||||||
transformers>=4.35
|
transformers>=4.35
|
||||||
|
transformers_stream_generator>=0.0.4
|
||||||
torch
|
torch
|
||||||
|
tiktoken
|
||||||
|
Loading…
Reference in New Issue
Block a user