some updates / style consistency

This commit is contained in:
Awni Hannun 2023-12-19 12:58:59 -08:00
parent 2a9c5e8a8c
commit a476ed9f50
4 changed files with 61 additions and 47 deletions

View File

@ -1,25 +1,37 @@
# Qwen # Qwen
Qwen (通义千问) is a language model proposed by Alibaba Cloud[^1]. The architecture of Qwen is similar to Llama except for the bias in the attention layers. Qwen (通义千问) is a language model developed by Alibaba Cloud.[^1] The
architecture of Qwen is similar to Llama except for the bias in the attention
layers.
## Setup ## Setup
Download (from huggingface) and conver the model. By default, the model is `Qwen/Qwen-1_8B`. First download and convert the model with:
```sh ```sh
python convert.py python convert.py
``` ```
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
This will make the `weights.npz` file which MLX can read. The conversion script will make the `weights.npz` and `params.json` files in
the working directory.
## Generate ## Generate
To generate text with the default prompt (default tokenizer is `Qwen/Qwen-1_8B`): To generate text with the default prompt:
```sh ```sh
python qwen.py python qwen.py
``` ```
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
for Qwen 7B use:
```
python qwen.py --tokenizer Qwen/Qwen-7B
```
To see a list of options, run: To see a list of options, run:
```sh ```sh

View File

@ -25,7 +25,7 @@ def convert(model_path: str = "Qwen/Qwen-1_8B"):
config = model.config config = model.config
config_dict = config.to_dict() config_dict = config.to_dict()
with open("config.json", "w") as f: with open("config.json", "w") as f:
json.dump(config_dict, f) json.dump(config_dict, f, indent=4)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,13 +1,9 @@
# The architecture of qwen is similar to Llama.
# This inference script is mainly for compatibility with the huggingface model of qwen.
import argparse import argparse
from dataclasses import dataclass
import json import json
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
from dataclasses import dataclass
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -17,37 +13,44 @@ class ModelArgs:
num_attention_heads: int = 16 num_attention_heads: int = 16
num_hidden_layers: int = 24 num_hidden_layers: int = 24
kv_channels: int = 128 kv_channels: int = 128
max_position_embeddings: int = 8192 max_position_embeddings: int = 8192
layer_norm_epsilon: float = 1e-6 layer_norm_epsilon: float = 1e-6
intermediate_size: int = 11008 intermediate_size: int = 11008
no_bias: bool = True no_bias: bool = True
vocab_size: int = 151936 vocab_size: int = 151936
class QWenAttntion(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.hidden_size = args.hidden_size hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads self.num_attention_heads = args.num_attention_heads
self.hidden_size_per_attention_head = ( hidden_size_per_attention_head = hidden_size // self.num_attention_heads
self.hidden_size // self.num_attention_heads
)
self.rotary_emb = nn.RoPE( self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
self.hidden_size_per_attention_head, traditional=False
)
self.proj_size = args.kv_channels * self.num_attention_heads proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True) self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias) self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
self.scale = self.hidden_size_per_attention_head**-0.5 self.scale = hidden_size_per_attention_head**-0.5
def __call__(self, x, mask=None, cache=None): def __call__(self, x, mask=None, cache=None):
qkv = self.c_attn(x) qkv = self.c_attn(x)
@ -76,13 +79,13 @@ class QWenAttntion(nn.Module):
if mask is not None: if mask is not None:
scores = scores + mask scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(v.dtype) scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1) v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(v_hat), (k, v) return self.c_proj(v_hat), (k, v)
class QWenMlp(nn.Module): class MLP(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
@ -99,19 +102,17 @@ class QWenMlp(nn.Module):
def __call__(self, x): def __call__(self, x):
a1 = self.w1(x) a1 = self.w1(x)
a2 = self.w2(x) a2 = self.w2(x)
intermediate_parallel = a1 * nn.silu(a2) return self.c_proj(a1 * nn.silu(a2))
out = self.c_proj(intermediate_parallel)
return out
class QWenBlock(nn.Module): class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = QWenAttntion(args) self.attn = Attention(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = QWenMlp(args) self.mlp = MLP(args)
def __call__(self, x, mask=None, cache=None): def __call__(self, x, mask=None, cache=None):
residual = x residual = x
@ -125,15 +126,15 @@ class QWenBlock(nn.Module):
return x, cache return x, cache
class QWen(nn.Module): class Qwen(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.embed_dim = args.hidden_size self.embed_dim = args.hidden_size
self.wte = nn.Embedding(args.vocab_size, args.hidden_size) self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)] self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon) self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False) self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
@ -141,8 +142,9 @@ class QWen(nn.Module):
x = self.wte(inputs) x = self.wte(inputs)
mask = None mask = None
if x.shape[1] > 1: T = x.shape[1]
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype) mask = mask.astype(x.dtype)
if cache is None: if cache is None:
@ -151,12 +153,11 @@ class QWen(nn.Module):
for e, layer in enumerate(self.h): for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e]) x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x) x = self.ln_f(x[:, T - 1 : T, :])
return self.lm_head(x), cache return self.lm_head(x), cache
def generate(prompt: mx.array, model: QWen, temp: 0.0): def generate(prompt: mx.array, model: Qwen, temp: 0.0):
def sample(logits): def sample(logits):
if temp == 0: if temp == 0:
return mx.argmax(logits, axis=-1) return mx.argmax(logits, axis=-1)
@ -190,7 +191,7 @@ def load_model(
model_args.intermediate_size = config["intermediate_size"] model_args.intermediate_size = config["intermediate_size"]
model_args.no_bias = config["no_bias"] model_args.no_bias = config["no_bias"]
model = QWen(model_args) model = Qwen(model_args)
weights = mx.load("weights.npz") weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items()))) model.update(tree_unflatten(list(weights.items())))
@ -201,8 +202,6 @@ def load_model(
if __name__ == "__main__": if __name__ == "__main__":
# The infernece code and arguments were mainly derived from phi-2 example.
parser = argparse.ArgumentParser(description="Qwen inference script") parser = argparse.ArgumentParser(description="Qwen inference script")
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",

View File

@ -1,4 +1,7 @@
einops
mlx mlx
numpy numpy
transformers>=4.35 transformers>=4.35
transformers_stream_generator>=0.0.4
torch torch
tiktoken