some updates / style consistency

This commit is contained in:
Awni Hannun 2023-12-19 12:58:59 -08:00
parent 2a9c5e8a8c
commit a476ed9f50
4 changed files with 61 additions and 47 deletions

View File

@ -1,25 +1,37 @@
# Qwen
Qwen (通义千问) is a language model proposed by Alibaba Cloud[^1]. The architecture of Qwen is similar to Llama except for the bias in the attention layers.
Qwen (通义千问) is a language model developed by Alibaba Cloud.[^1] The
architecture of Qwen is similar to Llama except for the bias in the attention
layers.
## Setup
Download (from huggingface) and conver the model. By default, the model is `Qwen/Qwen-1_8B`.
First download and convert the model with:
```sh
python convert.py
```
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
This will make the `weights.npz` file which MLX can read.
The conversion script will make the `weights.npz` and `params.json` files in
the working directory.
## Generate
To generate text with the default prompt (default tokenizer is `Qwen/Qwen-1_8B`):
To generate text with the default prompt:
```sh
python qwen.py
```
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
for Qwen 7B use:
```
python qwen.py --tokenizer Qwen/Qwen-7B
```
To see a list of options, run:
```sh

View File

@ -25,7 +25,7 @@ def convert(model_path: str = "Qwen/Qwen-1_8B"):
config = model.config
config_dict = config.to_dict()
with open("config.json", "w") as f:
json.dump(config_dict, f)
json.dump(config_dict, f, indent=4)
if __name__ == "__main__":

View File

@ -1,13 +1,9 @@
# The architecture of qwen is similar to Llama.
# This inference script is mainly for compatibility with the huggingface model of qwen.
import argparse
from dataclasses import dataclass
import json
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from dataclasses import dataclass
from transformers import AutoTokenizer
@ -17,37 +13,44 @@ class ModelArgs:
num_attention_heads: int = 16
num_hidden_layers: int = 24
kv_channels: int = 128
max_position_embeddings: int = 8192
layer_norm_epsilon: float = 1e-6
intermediate_size: int = 11008
no_bias: bool = True
vocab_size: int = 151936
class QWenAttntion(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
self.hidden_size_per_attention_head = (
self.hidden_size // self.num_attention_heads
)
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
self.rotary_emb = nn.RoPE(
self.hidden_size_per_attention_head, traditional=False
)
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
self.proj_size = args.kv_channels * self.num_attention_heads
proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
self.scale = self.hidden_size_per_attention_head**-0.5
self.scale = hidden_size_per_attention_head**-0.5
def __call__(self, x, mask=None, cache=None):
qkv = self.c_attn(x)
@ -76,13 +79,13 @@ class QWenAttntion(nn.Module):
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(v.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(v_hat), (k, v)
class QWenMlp(nn.Module):
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@ -99,19 +102,17 @@ class QWenMlp(nn.Module):
def __call__(self, x):
a1 = self.w1(x)
a2 = self.w2(x)
intermediate_parallel = a1 * nn.silu(a2)
out = self.c_proj(intermediate_parallel)
return out
return self.c_proj(a1 * nn.silu(a2))
class QWenBlock(nn.Module):
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = QWenAttntion(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = QWenMlp(args)
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(self, x, mask=None, cache=None):
residual = x
@ -125,15 +126,15 @@ class QWenBlock(nn.Module):
return x, cache
class QWen(nn.Module):
class Qwen(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embed_dim = args.hidden_size
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
@ -141,8 +142,9 @@ class QWen(nn.Module):
x = self.wte(inputs)
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
T = x.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype)
if cache is None:
@ -151,12 +153,11 @@ class QWen(nn.Module):
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x)
x = self.ln_f(x[:, T - 1 : T, :])
return self.lm_head(x), cache
def generate(prompt: mx.array, model: QWen, temp: 0.0):
def generate(prompt: mx.array, model: Qwen, temp: 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
@ -190,7 +191,7 @@ def load_model(
model_args.intermediate_size = config["intermediate_size"]
model_args.no_bias = config["no_bias"]
model = QWen(model_args)
model = Qwen(model_args)
weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items())))
@ -201,8 +202,6 @@ def load_model(
if __name__ == "__main__":
# The infernece code and arguments were mainly derived from phi-2 example.
parser = argparse.ArgumentParser(description="Qwen inference script")
parser.add_argument(
"--tokenizer",

View File

@ -1,4 +1,7 @@
einops
mlx
numpy
transformers>=4.35
transformers_stream_generator>=0.0.4
torch
tiktoken