diff --git a/qwen/README.md b/qwen/README.md index db1d61e9..31d17c7f 100644 --- a/qwen/README.md +++ b/qwen/README.md @@ -1,25 +1,37 @@ # Qwen -Qwen (通义千问) is a language model proposed by Alibaba Cloud[^1]. The architecture of Qwen is similar to Llama except for the bias in the attention layers. +Qwen (通义千问) is a language model developed by Alibaba Cloud.[^1] The +architecture of Qwen is similar to Llama except for the bias in the attention +layers. ## Setup -Download (from huggingface) and conver the model. By default, the model is `Qwen/Qwen-1_8B`. +First download and convert the model with: ```sh python convert.py ``` +The script downloads the model from Hugging Face. The default model is +`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models. -This will make the `weights.npz` file which MLX can read. +The conversion script will make the `weights.npz` and `params.json` files in +the working directory. ## Generate -To generate text with the default prompt (default tokenizer is `Qwen/Qwen-1_8B`): +To generate text with the default prompt: ```sh python qwen.py ``` +If you change the model, make sure to pass the corresponding tokenizer. E.g., +for Qwen 7B use: + +``` +python qwen.py --tokenizer Qwen/Qwen-7B +``` + To see a list of options, run: ```sh diff --git a/qwen/convert.py b/qwen/convert.py index 3406dde8..50a8d7a8 100644 --- a/qwen/convert.py +++ b/qwen/convert.py @@ -25,7 +25,7 @@ def convert(model_path: str = "Qwen/Qwen-1_8B"): config = model.config config_dict = config.to_dict() with open("config.json", "w") as f: - json.dump(config_dict, f) + json.dump(config_dict, f, indent=4) if __name__ == "__main__": diff --git a/qwen/qwen.py b/qwen/qwen.py index fa5cc834..c490d650 100644 --- a/qwen/qwen.py +++ b/qwen/qwen.py @@ -1,13 +1,9 @@ -# The architecture of qwen is similar to Llama. -# This inference script is mainly for compatibility with the huggingface model of qwen. - import argparse +from dataclasses import dataclass import json import mlx.core as mx import mlx.nn as nn - from mlx.utils import tree_unflatten -from dataclasses import dataclass from transformers import AutoTokenizer @@ -17,37 +13,44 @@ class ModelArgs: num_attention_heads: int = 16 num_hidden_layers: int = 24 kv_channels: int = 128 - max_position_embeddings: int = 8192 layer_norm_epsilon: float = 1e-6 intermediate_size: int = 11008 - no_bias: bool = True - vocab_size: int = 151936 -class QWenAttntion(nn.Module): +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.hidden_size = args.hidden_size + hidden_size = args.hidden_size self.num_attention_heads = args.num_attention_heads - self.hidden_size_per_attention_head = ( - self.hidden_size // self.num_attention_heads - ) + hidden_size_per_attention_head = hidden_size // self.num_attention_heads - self.rotary_emb = nn.RoPE( - self.hidden_size_per_attention_head, traditional=False - ) + self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False) - self.proj_size = args.kv_channels * self.num_attention_heads + proj_size = args.kv_channels * self.num_attention_heads - self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True) - self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias) + self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True) + self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias) - self.scale = self.hidden_size_per_attention_head**-0.5 + self.scale = hidden_size_per_attention_head**-0.5 def __call__(self, x, mask=None, cache=None): qkv = self.c_attn(x) @@ -76,13 +79,13 @@ class QWenAttntion(nn.Module): if mask is not None: scores = scores + mask - scores = mx.softmax(scores, axis=-1).astype(v.dtype) + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.c_proj(v_hat), (k, v) -class QWenMlp(nn.Module): +class MLP(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -99,19 +102,17 @@ class QWenMlp(nn.Module): def __call__(self, x): a1 = self.w1(x) a2 = self.w2(x) - intermediate_parallel = a1 * nn.silu(a2) - out = self.c_proj(intermediate_parallel) - return out + return self.c_proj(a1 * nn.silu(a2)) -class QWenBlock(nn.Module): +class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.attn = QWenAttntion(args) - self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.mlp = QWenMlp(args) + self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.attn = Attention(args) + self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.mlp = MLP(args) def __call__(self, x, mask=None, cache=None): residual = x @@ -125,15 +126,15 @@ class QWenBlock(nn.Module): return x, cache -class QWen(nn.Module): +class Qwen(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.embed_dim = args.hidden_size self.wte = nn.Embedding(args.vocab_size, args.hidden_size) - self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)] - self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon) + self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] + self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon) self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False) @@ -141,8 +142,9 @@ class QWen(nn.Module): x = self.wte(inputs) mask = None - if x.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + T = x.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) mask = mask.astype(x.dtype) if cache is None: @@ -151,12 +153,11 @@ class QWen(nn.Module): for e, layer in enumerate(self.h): x, cache[e] = layer(x, mask, cache[e]) - x = self.ln_f(x) - + x = self.ln_f(x[:, T - 1 : T, :]) return self.lm_head(x), cache -def generate(prompt: mx.array, model: QWen, temp: 0.0): +def generate(prompt: mx.array, model: Qwen, temp: 0.0): def sample(logits): if temp == 0: return mx.argmax(logits, axis=-1) @@ -190,7 +191,7 @@ def load_model( model_args.intermediate_size = config["intermediate_size"] model_args.no_bias = config["no_bias"] - model = QWen(model_args) + model = Qwen(model_args) weights = mx.load("weights.npz") model.update(tree_unflatten(list(weights.items()))) @@ -201,8 +202,6 @@ def load_model( if __name__ == "__main__": - # The infernece code and arguments were mainly derived from phi-2 example. - parser = argparse.ArgumentParser(description="Qwen inference script") parser.add_argument( "--tokenizer", diff --git a/qwen/requirements.txt b/qwen/requirements.txt index 8a318500..0ce17aec 100644 --- a/qwen/requirements.txt +++ b/qwen/requirements.txt @@ -1,4 +1,7 @@ +einops mlx numpy transformers>=4.35 -torch \ No newline at end of file +transformers_stream_generator>=0.0.4 +torch +tiktoken