From 30be4c473408800e9e343e2e0df805a27e1ab156 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 22 Jan 2024 15:00:07 -0800
Subject: [PATCH] refactor(qwen): moving qwen into mlx-lm (#312)
* refactor(qwen): moving qwen into mlx-lm
* chore: update doc
* chore: fix type hint
* add qwen model support in convert
* chore: fix doc
* chore: only load model in quantize_model
* chore: make the convert script copy the tokenizer files instead of loading and re-saving them
* chore: update docstring
* chore: remove unnecessary try catch
* chore: clean up tokenizer handling and update the transformers requirement to 4.37
* nits in README
---------
Co-authored-by: Awni Hannun
---
llms/README.md | 17 ++-
llms/mlx_lm/generate.py | 19 +++-
llms/{qwen => mlx_lm/models}/qwen.py | 159 +++++----------------------
llms/mlx_lm/requirements.txt | 2 +-
llms/mlx_lm/utils.py | 25 +++--
llms/qwen/README.md | 45 --------
llms/qwen/convert.py | 115 -------------------
llms/qwen/requirements.txt | 7 --
8 files changed, 80 insertions(+), 309 deletions(-)
rename llms/{qwen => mlx_lm/models}/qwen.py (50%)
delete mode 100644 llms/qwen/README.md
delete mode 100644 llms/qwen/convert.py
delete mode 100644 llms/qwen/requirements.txt
diff --git a/llms/README.md b/llms/README.md
index 40dbc067..ad4940a8 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -102,11 +102,26 @@ Here are a few examples of Hugging Face models that work with this example:
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
-[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
+[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
and
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
style models should work out of the box.
+
+For
+[Qwen](https://huggingface.co/models?library=transformers,safetensors&other=qwen&sort=trending)
+style models, you must enable the `trust_remote_code` option and specify the
+`eos_token`. This ensures the tokenizer works correctly. You can do this by
+passing `--trust-remote-code` and `--eos-token "<|endoftext|>"` on the command
+line, or by setting these options in the Python API:
+
+```python
+model, tokenizer = load(
+ "qwen/Qwen-7B",
+ tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
+)
+```
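
Continuing the example, a quick way to confirm that the tokenizer options took effect is a single next-token prediction with the loaded pair. This is only an illustrative sketch (the prompt string is adapted from the old example's default); per the Qwen `Model` class in this patch, the loaded model returns `(logits, cache)`:

```python
import mlx.core as mx

# Encode a prompt and predict one token greedily with the loaded model.
prompt = mx.array([tokenizer.encode("蒙古国的首都是乌兰巴托")])
logits, cache = model(prompt)                      # (logits, cache) per the Qwen Model class
next_token = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick over the vocabulary
print(tokenizer.decode([next_token.item()]))
```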
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 237fb056..530a3483 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -21,6 +21,17 @@ def setup_arg_parser():
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Enable trusting remote code for tokenizer",
+ )
+ parser.add_argument(
+ "--eos-token",
+ type=str,
+ default=None,
+ help="End of sequence token for tokenizer",
+ )
parser.add_argument(
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
)
@@ -40,7 +51,13 @@ def setup_arg_parser():
def main(args):
mx.random.seed(args.seed)
- model, tokenizer = load(args.model)
+
+ # Building tokenizer_config
+ tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+ if args.eos_token is not None:
+ tokenizer_config["eos_token"] = args.eos_token
+
+ model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
print("=" * 10)
print("Prompt:", args.prompt)
prompt = tokenizer.encode(args.prompt)
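
For reference, the flag handling above only forwards options the user actually set; a small standalone sketch of that mapping (the `build_tokenizer_config` helper is illustrative, not part of the patch):

```python
from types import SimpleNamespace

def build_tokenizer_config(args) -> dict:
    # Mirrors main(): trust_remote_code defaults to None (AutoTokenizer's default),
    # and eos_token is only forwarded when the user provided one.
    config = {"trust_remote_code": True if args.trust_remote_code else None}
    if args.eos_token is not None:
        config["eos_token"] = args.eos_token
    return config

print(build_tokenizer_config(SimpleNamespace(trust_remote_code=True, eos_token="<|endoftext|>")))
# {'trust_remote_code': True, 'eos_token': '<|endoftext|>'}
print(build_tokenizer_config(SimpleNamespace(trust_remote_code=False, eos_token=None)))
# {'trust_remote_code': None}
```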
diff --git a/llms/qwen/qwen.py b/llms/mlx_lm/models/qwen.py
similarity index 50%
rename from llms/qwen/qwen.py
rename to llms/mlx_lm/models/qwen.py
index 532a8031..a086a95c 100644
--- a/llms/qwen/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -1,16 +1,14 @@
-import argparse
-import json
from dataclasses import dataclass
-from pathlib import Path
+from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer
+
+from .base import BaseModelArgs
@dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
@@ -20,6 +18,11 @@ class ModelArgs:
intermediate_size: int = 11008
no_bias: bool = True
vocab_size: int = 151936
+    num_key_value_heads: int = None
+
+ def __post_init__(self):
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
@@ -95,7 +98,7 @@ class MLP(nn.Module):
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.w2 = nn.Linear(
- args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
+ args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.c_proj = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
@@ -128,17 +131,12 @@ class TransformerBlock(nn.Module):
return x, cache
-class Qwen(nn.Module):
+class QwenModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
-
- self.embed_dim = args.hidden_size
-
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
- self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
-
- self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
+ self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
@@ -156,123 +154,22 @@ class Qwen(nn.Module):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x[:, T - 1 : T, :])
- return self.lm_head(x), cache
+ return x, cache
-def generate(prompt: mx.array, model: Qwen, temp: 0.0):
- def sample(logits):
- if temp == 0:
- return mx.argmax(logits, axis=-1)
- else:
- return mx.random.categorical(logits * (1 / temp))
+class Model(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.transformer = QwenModel(config)
+ self.lm_head = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=not config.no_bias
+ )
- logits, cache = model(prompt)
- y = sample(logits[:, -1, :])
- yield y
-
- while True:
- logits, cache = model(y[:, None], cache=cache)
- y = sample(logits.squeeze(1))
- yield y
-
-
-def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
- model_args = ModelArgs()
-
- model_path = Path(model_path)
- with open(model_path / "config.json", "r") as f:
- config = json.load(f)
- model_args.vocab_size = config["vocab_size"]
- model_args.hidden_size = config["hidden_size"]
- model_args.num_attention_heads = config["num_attention_heads"]
- model_args.num_hidden_layers = config["num_hidden_layers"]
- model_args.kv_channels = config["kv_channels"]
- model_args.max_position_embeddings = config["max_position_embeddings"]
- model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
- model_args.intermediate_size = config["intermediate_size"]
- model_args.no_bias = config["no_bias"]
-
- model = Qwen(model_args)
- weights = mx.load(str(model_path / "weights.npz"))
- if quantization := config.get("quantization", False):
- nn.QuantizedLinear.quantize_module(model, **quantization)
- model.update(tree_unflatten(list(weights.items())))
-
- tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
- )
- return model, tokenizer
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Qwen inference script")
- parser.add_argument(
- "--model-path",
- type=str,
- default="mlx_model",
- help="The path to the model weights and config",
- )
- parser.add_argument(
- "--tokenizer",
- help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
- default="Qwen/Qwen-1_8B",
- )
- parser.add_argument(
- "--prompt",
- help="The message to be processed by the model",
- # The example from the official huggingface repo of Qwen
- default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
- )
- parser.add_argument(
- "--max-tokens",
- "-m",
- type=int,
- default=100,
- help="Maximum number of tokens to generate",
- )
- parser.add_argument(
- "--temp",
- help="The sampling temperature.",
- type=float,
- default=0.0,
- )
- parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
- args = parser.parse_args()
-
- mx.random.seed(args.seed)
-
- model, tokenizer = load_model(args.model_path, args.tokenizer)
-
- prompt = tokenizer(
- args.prompt,
- return_tensors="np",
- return_attention_mask=False,
- )["input_ids"]
-
- prompt = mx.array(prompt)
-
- print(args.prompt, end="", flush=True)
-
- tokens = []
- for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
- tokens.append(token)
-
- if (len(tokens) % 10) == 0:
- mx.eval(tokens)
- eos_index = next(
- (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
- None,
- )
-
- if eos_index is not None:
- tokens = tokens[:eos_index]
-
- s = tokenizer.decode([t.item() for t in tokens])
- print(s, end="", flush=True)
- tokens = []
- if eos_index is not None:
- break
-
- mx.eval(tokens)
- s = tokenizer.decode([t.item() for t in tokens])
- print(s, flush=True)
+ def __call__(
+ self,
+ x: mx.array,
+ mask: mx.array = None,
+ cache: mx.array = None,
+ ) -> Tuple[mx.array, mx.array]:
+ y, cache = self.transformer(x, mask, cache)
+ return self.lm_head(y), cache
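
The old script's sampling loop is removed above; generation now goes through the shared `mlx_lm` pipeline. For illustration, a minimal greedy loop against the refactored `(logits, cache)` interface could look like this, with `model` and `tokenizer` coming from `mlx_lm.load` (the `greedy_generate` name is illustrative; the real package code adds temperature sampling and streaming output):

```python
import mlx.core as mx

def greedy_generate(model, tokenizer, prompt: str, max_tokens: int = 100) -> str:
    tokens = mx.array([tokenizer.encode(prompt)])
    logits, cache = model(tokens)              # prime the per-layer cache with the prompt
    y = mx.argmax(logits[:, -1, :], axis=-1)
    out = [y.item()]
    for _ in range(max_tokens - 1):
        logits, cache = model(y[:, None], cache=cache)  # feed one token at a time
        y = mx.argmax(logits.squeeze(1), axis=-1)
        if y.item() == tokenizer.eos_token_id:
            break
        out.append(y.item())
    return tokenizer.decode(out)
```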
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index c78cefa2..a04cc7bb 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
mlx
numpy
-transformers
+transformers>=4.37.0
protobuf
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index a7eaea52..5e8f8e2a 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -10,8 +10,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, PreTrainedTokenizer
# Local imports
-from .models import llama, mixtral, phi2
-from .models.base import BaseModelArgs
+from .models import llama, mixtral, phi2, qwen
# Constants
MODEL_MAPPING = {
@@ -19,6 +18,7 @@ MODEL_MAPPING = {
"mistral": llama, # mistral is compatible with llama
"mixtral": mixtral,
"phi": phi2,
+ "qwen": qwen,
}
linear_class_predicate = (
@@ -64,7 +64,13 @@ def get_model_path(path_or_hf_repo: str) -> Path:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
- allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
+ allow_patterns=[
+ "*.json",
+ "*.safetensors",
+ "*.py",
+ "tokenizer.model",
+ "*.tiktoken",
+ ],
)
)
return model_path
@@ -196,15 +202,18 @@ def load_model(model_path: Path) -> nn.Module:
return model
-def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+def load(
+ path_or_hf_repo: str, tokenizer_config={}
+) -> Tuple[nn.Module, PreTrainedTokenizer]:
"""
Load the model from a given path or a huggingface repository.
Args:
- path_or_hf_repo (str): The path or the huggingface repository to load the model from.
-
+        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+ tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+ Defaults to an empty dictionary.
Returns:
- Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and the tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
@@ -213,5 +222,5 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path)
- tokenizer = AutoTokenizer.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
return model, tokenizer
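
Putting the pieces together, loading a Qwen checkpoint now uses the same entry point as the other architectures: `MODEL_MAPPING["qwen"]` routes the checkpoint's `model_type` to `models/qwen.py`, and `tokenizer_config` is forwarded to `AutoTokenizer.from_pretrained`. A usage sketch (assumes `load` is re-exported at the package top level, as in the package README, and that the Hub is reachable):

```python
from mlx_lm import load

model, tokenizer = load(
    "Qwen/Qwen-7B",
    tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
print(tokenizer.eos_token)  # expected: <|endoftext|>
```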
diff --git a/llms/qwen/README.md b/llms/qwen/README.md
deleted file mode 100644
index 75154325..00000000
--- a/llms/qwen/README.md
+++ /dev/null
@@ -1,45 +0,0 @@
-# Qwen
-
-Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1]
-The architecture of the Qwen models is similar to Llama except for the bias in
-the attention layers.
-
-## Setup
-
-First download and convert the model with:
-
-```sh
-python convert.py
-```
-
-To generate a 4-bit quantized model, use ``-q``. For a full list of options:
-
-The script downloads the model from Hugging Face. The default model is
-`Qwen/Qwen-1_8B`. Check out the [Hugging Face
-page](https://huggingface.co/Qwen) to see a list of available models.
-
-By default, the conversion script will make the directory `mlx_model` and save
-the converted `weights.npz` and `config.json` there.
-
-## Generate
-
-To generate text with the default prompt:
-
-```sh
-python qwen.py
-```
-
-If you change the model, make sure to pass the corresponding tokenizer. E.g.,
-for Qwen 7B use:
-
-```
-python qwen.py --tokenizer Qwen/Qwen-7B
-```
-
-To see a list of options, run:
-
-```sh
-python qwen.py --help
-```
-
-[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen).
diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py
deleted file mode 100644
index 9f7bed89..00000000
--- a/llms/qwen/convert.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import argparse
-import copy
-import json
-from pathlib import Path
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-import torch
-from mlx.utils import tree_flatten, tree_map, tree_unflatten
-from qwen import ModelArgs, Qwen
-from transformers import AutoModelForCausalLM
-
-
-def replace_key(key: str) -> str:
- if key.startswith("transformer."):
- # remove transformer prefix
- key = key.replace("transformer.", "")
-
- return key
-
-
-def quantize(weights, config, args):
- quantized_config = copy.deepcopy(config)
-
- # Load the model:
- model_args = ModelArgs()
- model_args.vocab_size = config["vocab_size"]
- model_args.hidden_size = config["hidden_size"]
- model_args.num_attention_heads = config["num_attention_heads"]
- model_args.num_hidden_layers = config["num_hidden_layers"]
- model_args.kv_channels = config["kv_channels"]
- model_args.max_position_embeddings = config["max_position_embeddings"]
- model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
- model_args.intermediate_size = config["intermediate_size"]
- model_args.no_bias = config["no_bias"]
- model = Qwen(model_args)
-
- weights = tree_map(mx.array, weights)
- model.update(tree_unflatten(list(weights.items())))
-
- # Quantize the model:
- nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
-
- # Update the config:
- quantized_config["quantization"] = {
- "group_size": args.q_group_size,
- "bits": args.q_bits,
- }
- quantized_weights = dict(tree_flatten(model.parameters()))
-
- return quantized_weights, quantized_config
-
-
-def convert(args):
- mlx_path = Path(args.mlx_path)
- mlx_path.mkdir(parents=True, exist_ok=True)
-
- model = AutoModelForCausalLM.from_pretrained(
- args.model, trust_remote_code=True, torch_dtype=torch.float16
- )
- state_dict = model.state_dict()
- weights = {
- replace_key(k): (
- v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()
- )
- for k, v in state_dict.items()
- }
- config = model.config.to_dict()
-
- if args.quantize:
- print("[INFO] Quantizing")
- weights, config = quantize(weights, config, args)
-
- np.savez(str(mlx_path / "weights.npz"), **weights)
-
- # write config
- with open(mlx_path / "config.json", "w") as f:
- json.dump(config, f, indent=4)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
-
- parser.add_argument(
- "--model",
- help="The huggingface model to be converted",
- default="Qwen/Qwen-1_8B",
- )
- parser.add_argument(
- "--mlx-path",
- type=str,
- default="mlx_model",
- help="The path to save the MLX model.",
- )
- parser.add_argument(
- "-q",
- "--quantize",
- help="Generate a quantized model.",
- action="store_true",
- )
- parser.add_argument(
- "--q-group-size",
- help="Group size for quantization.",
- type=int,
- default=64,
- )
- parser.add_argument(
- "--q-bits",
- help="Bits per weight for quantization.",
- type=int,
- default=4,
- )
- args = parser.parse_args()
- convert(args)
diff --git a/llms/qwen/requirements.txt b/llms/qwen/requirements.txt
deleted file mode 100644
index 0ce17aec..00000000
--- a/llms/qwen/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-einops
-mlx
-numpy
-transformers>=4.35
-transformers_stream_generator>=0.0.4
-torch
-tiktoken