mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
some updates / style consistency
This commit is contained in:
parent
2a9c5e8a8c
commit
a476ed9f50
@ -1,25 +1,37 @@
|
||||
# Qwen
|
||||
|
||||
Qwen (通义千问) is a language model proposed by Alibaba Cloud[^1]. The architecture of Qwen is similar to Llama except for the bias in the attention layers.
|
||||
Qwen (通义千问) is a language model developed by Alibaba Cloud.[^1] The
|
||||
architecture of Qwen is similar to Llama except for the bias in the attention
|
||||
layers.
|
||||
|
||||
## Setup
|
||||
|
||||
Download (from Hugging Face) and convert the model. By default, the model is `Qwen/Qwen-1_8B`.
|
||||
First download and convert the model with:
|
||||
|
||||
```sh
|
||||
python convert.py
|
||||
```
|
||||
The script downloads the model from Hugging Face. The default model is
|
||||
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
|
||||
|
||||
This will make the `weights.npz` file which MLX can read.
|
||||
The conversion script will make the `weights.npz` and `params.json` files in
|
||||
the working directory.
|
||||
|
||||
## Generate
|
||||
|
||||
To generate text with the default prompt (default tokenizer is `Qwen/Qwen-1_8B`):
|
||||
To generate text with the default prompt:
|
||||
|
||||
```sh
|
||||
python qwen.py
|
||||
```
|
||||
|
||||
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
|
||||
for Qwen 7B use:
|
||||
|
||||
```sh
|
||||
python qwen.py --tokenizer Qwen/Qwen-7B
|
||||
```
|
||||
|
||||
To see a list of options, run:
|
||||
|
||||
```sh
|
||||
|
@ -25,7 +25,7 @@ def convert(model_path: str = "Qwen/Qwen-1_8B"):
|
||||
config = model.config
|
||||
config_dict = config.to_dict()
|
||||
with open("config.json", "w") as f:
|
||||
json.dump(config_dict, f)
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
81
qwen/qwen.py
81
qwen/qwen.py
@ -1,13 +1,9 @@
|
||||
# The architecture of qwen is similar to Llama.
|
||||
# This inference script is mainly for compatibility with the huggingface model of qwen.
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
from dataclasses import dataclass
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
@ -17,37 +13,44 @@ class ModelArgs:
|
||||
num_attention_heads: int = 16
|
||||
num_hidden_layers: int = 24
|
||||
kv_channels: int = 128
|
||||
|
||||
max_position_embeddings: int = 8192
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
intermediate_size: int = 11008
|
||||
|
||||
no_bias: bool = True
|
||||
|
||||
vocab_size: int = 151936
|
||||
|
||||
|
||||
class QWenAttntion(nn.Module):
|
||||
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    The normalization itself is evaluated in float32 and the result is cast
    back to the input dtype before the per-channel ``weight`` is applied.
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        """Create the scale vector (initialized to ones) and store ``eps``.

        Args:
            dims: Length of the ``weight`` vector.
            eps: Constant added to the mean square before the reciprocal
                square root is taken.
        """
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal root of its mean square (last axis)."""
        mean_square = x.square().mean(-1, keepdims=True)
        return x * mx.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        """Normalize ``x`` and multiply by the learned ``weight``."""
        # Upcast for the reduction, then return to the caller's dtype.
        normalized = self._norm(x.astype(mx.float32))
        return self.weight * normalized.astype(x.dtype)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
|
||||
self.hidden_size_per_attention_head = (
|
||||
self.hidden_size // self.num_attention_heads
|
||||
)
|
||||
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
|
||||
|
||||
self.rotary_emb = nn.RoPE(
|
||||
self.hidden_size_per_attention_head, traditional=False
|
||||
)
|
||||
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
|
||||
|
||||
self.proj_size = args.kv_channels * self.num_attention_heads
|
||||
proj_size = args.kv_channels * self.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
|
||||
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
|
||||
|
||||
self.scale = self.hidden_size_per_attention_head**-0.5
|
||||
self.scale = hidden_size_per_attention_head**-0.5
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.c_attn(x)
|
||||
@ -76,13 +79,13 @@ class QWenAttntion(nn.Module):
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1).astype(v.dtype)
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.c_proj(v_hat), (k, v)
|
||||
|
||||
|
||||
class QWenMlp(nn.Module):
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
@ -99,19 +102,17 @@ class QWenMlp(nn.Module):
|
||||
def __call__(self, x):
|
||||
a1 = self.w1(x)
|
||||
a2 = self.w2(x)
|
||||
intermediate_parallel = a1 * nn.silu(a2)
|
||||
out = self.c_proj(intermediate_parallel)
|
||||
return out
|
||||
return self.c_proj(a1 * nn.silu(a2))
|
||||
|
||||
|
||||
class QWenBlock(nn.Module):
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = QWenAttntion(args)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = QWenMlp(args)
|
||||
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
residual = x
|
||||
@ -125,15 +126,15 @@ class QWenBlock(nn.Module):
|
||||
return x, cache
|
||||
|
||||
|
||||
class QWen(nn.Module):
|
||||
class Qwen(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.embed_dim = args.hidden_size
|
||||
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
|
||||
|
||||
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
|
||||
|
||||
@ -141,8 +142,9 @@ class QWen(nn.Module):
|
||||
x = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
T = x.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
if cache is None:
|
||||
@ -151,12 +153,11 @@ class QWen(nn.Module):
|
||||
for e, layer in enumerate(self.h):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
|
||||
x = self.ln_f(x)
|
||||
|
||||
x = self.ln_f(x[:, T - 1 : T, :])
|
||||
return self.lm_head(x), cache
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: QWen, temp: 0.0):
|
||||
def generate(prompt: mx.array, model: Qwen, temp: 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
@ -190,7 +191,7 @@ def load_model(
|
||||
model_args.intermediate_size = config["intermediate_size"]
|
||||
model_args.no_bias = config["no_bias"]
|
||||
|
||||
model = QWen(model_args)
|
||||
model = Qwen(model_args)
|
||||
|
||||
weights = mx.load("weights.npz")
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
@ -201,8 +202,6 @@ def load_model(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The inference code and arguments were mainly derived from the phi-2 example.
|
||||
|
||||
parser = argparse.ArgumentParser(description="Qwen inference script")
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
|
@ -1,4 +1,7 @@
|
||||
einops
|
||||
mlx
|
||||
numpy
|
||||
transformers>=4.35
|
||||
transformers_stream_generator>=0.0.4
|
||||
torch
|
||||
tiktoken
|
||||
|
Loading…
Reference in New Issue
Block a user