refactor(qwen): moving qwen into mlx-lm (#312)

* refactor(qwen): moving qwen into mlx-lm

* chore: update doc

* chore: fix type hint

* add qwen model support in convert

* chore: fix doc

* chore: only load model in quantize_model

* chore: make the convert script copy tokenizer files instead of loading and re-saving them

* chore: update docstring

* chore: remove unnecessary try catch

* chore: clean up tokenizer handling and update transformers to 4.37

* nits in README

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Anchen 2024-01-22 15:00:07 -08:00 committed by GitHub
parent de15532da8
commit 30be4c4734
8 changed files with 80 additions and 309 deletions

View File

@@ -102,11 +102,26 @@ Here are a few examples of Hugging Face models that work with this example:
 - [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
 - [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
 - [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
 
 Most
 [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
 [Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
-[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
+[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
 and
 [Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
 style models should work out of the box.
+
+For
+[Qwen](https://huggingface.co/models?library=transformers,safetensors&other=qwen&sort=trending)
+style models, you must enable the `trust_remote_code` option and specify the
+`eos_token`. This ensures the tokenizer works correctly. You can do this by
+passing `--trust-remote-code` and `--eos-token "<|endoftext|>"` in the command
+line, or by setting these options in the Python API:
+
+```python
+model, tokenizer = load(
+    "qwen/Qwen-7B",
+    tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
+)
+```
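For reference, the command-line path described in this new README text would look roughly like the sketch below. The entry point is an assumption (the diff does not show how the package is invoked); the flags themselves are the ones added to the generate script in the next file.

```sh
# Sketch only: entry-point name assumed, flags taken from this commit.
python -m mlx_lm.generate \
    --model Qwen/Qwen-7B \
    --trust-remote-code \
    --eos-token "<|endoftext|>" \
    --prompt "蒙古国的首都是"
```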

View File

@@ -21,6 +21,17 @@ def setup_arg_parser():
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable trusting remote code for tokenizer",
+    )
+    parser.add_argument(
+        "--eos-token",
+        type=str,
+        default=None,
+        help="End of sequence token for tokenizer",
+    )
     parser.add_argument(
         "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
     )
@@ -40,7 +51,13 @@ def setup_arg_parser():
 
 def main(args):
     mx.random.seed(args.seed)
-    model, tokenizer = load(args.model)
+
+    # Building tokenizer_config
+    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.eos_token is not None:
+        tokenizer_config["eos_token"] = args.eos_token
+
+    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
     print("=" * 10)
     print("Prompt:", args.prompt)
     prompt = tokenizer.encode(args.prompt)
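As a hedged illustration of the dict built in `main` above, a hypothetical `--trust-remote-code --eos-token "<|endoftext|>"` invocation produces the following, which `load` (see the utils change later in this diff) forwards to `AutoTokenizer.from_pretrained`:

```python
# Sketch of the resulting tokenizer_config (values assumed for illustration).
tokenizer_config = {
    "trust_remote_code": True,      # set by --trust-remote-code
    "eos_token": "<|endoftext|>",   # set by --eos-token
}
model, tokenizer = load("Qwen/Qwen-7B", tokenizer_config=tokenizer_config)
```

When `--trust-remote-code` is not passed, the dict carries `trust_remote_code: None`, which transformers should treat the same as leaving the argument unset.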

View File

@@ -1,16 +1,14 @@
-import argparse
-import json
 from dataclasses import dataclass
-from pathlib import Path
+from typing import Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer
+
+from .base import BaseModelArgs
 
 
 @dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
     hidden_size: int = 2048
     num_attention_heads: int = 16
     num_hidden_layers: int = 24
@@ -20,6 +18,11 @@ class ModelArgs:
     intermediate_size: int = 11008
     no_bias: bool = True
     vocab_size: int = 151936
+    num_key_value_heads = None
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
 
 
 class RMSNorm(nn.Module):
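A small example of what the new `__post_init__` default means in practice (assuming the `ModelArgs` defined in this file): configs that omit `num_key_value_heads` fall back to standard multi-head attention.

```python
# Omitting num_key_value_heads defaults it to num_attention_heads.
args = ModelArgs(num_attention_heads=16)
assert args.num_key_value_heads == 16
```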
@@ -95,7 +98,7 @@ class MLP(nn.Module):
             args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
         )
         self.w2 = nn.Linear(
-            args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
+            args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
         )
         self.c_proj = nn.Linear(
             args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
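The `w2` change above fixes the projection's orientation: in Qwen's SwiGLU-style MLP the outputs of `w1` and `w2` are multiplied elementwise before `c_proj`, so both must map `hidden_size` to `intermediate_size // 2`. A minimal shape check, assuming a forward pass of roughly `c_proj(w1(x) * silu(w2(x)))`:

```python
import mlx.core as mx
import mlx.nn as nn

hidden, inter = 2048, 11008
w1 = nn.Linear(hidden, inter // 2, bias=False)
w2 = nn.Linear(hidden, inter // 2, bias=False)   # corrected orientation from this hunk
c_proj = nn.Linear(inter // 2, hidden, bias=False)

x = mx.zeros((1, 4, hidden))
y = c_proj(w1(x) * nn.silu(w2(x)))               # elementwise product needs matching widths
print(y.shape)                                   # (1, 4, 2048)
```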
@@ -128,17 +131,12 @@ class TransformerBlock(nn.Module):
         return x, cache
 
 
-class Qwen(nn.Module):
+class QwenModel(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.embed_dim = args.hidden_size
         self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
         self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
-        self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
-        self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
+        self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
 
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
@@ -156,123 +154,22 @@ class Qwen(nn.Module):
             x, cache[e] = layer(x, mask, cache[e])
 
         x = self.ln_f(x[:, T - 1 : T, :])
-        return self.lm_head(x), cache
-
-
-def generate(prompt: mx.array, model: Qwen, temp: 0.0):
-    def sample(logits):
-        if temp == 0:
-            return mx.argmax(logits, axis=-1)
-        else:
-            return mx.random.categorical(logits * (1 / temp))
-
-    logits, cache = model(prompt)
-    y = sample(logits[:, -1, :])
-    yield y
-
-    while True:
-        logits, cache = model(y[:, None], cache=cache)
-        y = sample(logits.squeeze(1))
-        yield y
-
-
-def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
-    model_args = ModelArgs()
-    model_path = Path(model_path)
-    with open(model_path / "config.json", "r") as f:
-        config = json.load(f)
-
-    model_args.vocab_size = config["vocab_size"]
-    model_args.hidden_size = config["hidden_size"]
-    model_args.num_attention_heads = config["num_attention_heads"]
-    model_args.num_hidden_layers = config["num_hidden_layers"]
-    model_args.kv_channels = config["kv_channels"]
-    model_args.max_position_embeddings = config["max_position_embeddings"]
-    model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
-    model_args.intermediate_size = config["intermediate_size"]
-    model_args.no_bias = config["no_bias"]
-
-    model = Qwen(model_args)
-
-    weights = mx.load(str(model_path / "weights.npz"))
-    if quantization := config.get("quantization", False):
-        nn.QuantizedLinear.quantize_module(model, **quantization)
-    model.update(tree_unflatten(list(weights.items())))
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
-    )
-    return model, tokenizer
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Qwen inference script")
-    parser.add_argument(
-        "--model-path",
-        type=str,
-        default="mlx_model",
-        help="The path to the model weights and config",
-    )
-    parser.add_argument(
-        "--tokenizer",
-        help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
-        default="Qwen/Qwen-1_8B",
-    )
-    parser.add_argument(
-        "--prompt",
-        help="The message to be processed by the model",
-        # The example from the official huggingface repo of Qwen
-        default="蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是",
-    )
-    parser.add_argument(
-        "--max-tokens",
-        "-m",
-        type=int,
-        default=100,
-        help="Maximum number of tokens to generate",
-    )
-    parser.add_argument(
-        "--temp",
-        help="The sampling temperature.",
-        type=float,
-        default=0.0,
-    )
-    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
-
-    args = parser.parse_args()
-    mx.random.seed(args.seed)
-
-    model, tokenizer = load_model(args.model_path, args.tokenizer)
-
-    prompt = tokenizer(
-        args.prompt,
-        return_tensors="np",
-        return_attention_mask=False,
-    )["input_ids"]
-    prompt = mx.array(prompt)
-
-    print(args.prompt, end="", flush=True)
-
-    tokens = []
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
-        tokens.append(token)
-
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
-                None,
-            )
-
-            if eos_index is not None:
-                tokens = tokens[:eos_index]
-
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            tokens = []
-
-            if eos_index is not None:
-                break
-
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    print(s, flush=True)
+        return x, cache
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.transformer = QwenModel(config)
+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=not config.no_bias
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: mx.array = None,
+        cache: mx.array = None,
+    ) -> Tuple[mx.array, mx.array]:
+        y, cache = self.transformer(x, mask, cache)
+        return self.lm_head(y), cache
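The standalone `generate` loop and `load_model` helper removed above are replaced by the shared mlx-lm loading path. A minimal usage sketch of the refactored module (repo id and import path are assumptions based on the package layout in this commit):

```python
import mlx.core as mx
from mlx_lm.utils import load  # load() is shown later in this diff

model, tokenizer = load(
    "Qwen/Qwen-1_8B",
    tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
prompt = mx.array(tokenizer.encode("蒙古国的首都是乌兰巴托"))[None]
logits, cache = model(prompt)                      # Model.__call__ returns (logits, cache)
next_token = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick of the next token
print(tokenizer.decode(next_token.tolist()))
```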

View File

@ -1,4 +1,4 @@
mlx mlx
numpy numpy
transformers transformers>=4.37.0
protobuf protobuf

View File

@@ -10,8 +10,7 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
-from .models import llama, mixtral, phi2
-from .models.base import BaseModelArgs
+from .models import llama, mixtral, phi2, qwen
 
 # Constants
 MODEL_MAPPING = {
@@ -19,6 +18,7 @@ MODEL_MAPPING = {
     "mistral": llama,  # mistral is compatible with llama
     "mixtral": mixtral,
     "phi": phi2,
+    "qwen": qwen,
 }
 
 linear_class_predicate = (
@@ -64,7 +64,13 @@ def get_model_path(path_or_hf_repo: str) -> Path:
         model_path = Path(
             snapshot_download(
                 repo_id=path_or_hf_repo,
-                allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
+                allow_patterns=[
+                    "*.json",
+                    "*.safetensors",
+                    "*.py",
+                    "tokenizer.model",
+                    "*.tiktoken",
+                ],
             )
         )
     return model_path
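The new `*.tiktoken` pattern matters because Qwen repositories ship their BPE vocabulary as a tiktoken file rather than a `tokenizer.model`. A hedged check of the effect (the exact filename, typically `qwen.tiktoken`, is an assumption):

```python
# After this change, a Qwen snapshot should include its tiktoken vocabulary file.
model_path = get_model_path("Qwen/Qwen-1_8B")
assert any(p.suffix == ".tiktoken" for p in model_path.iterdir())
```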
@@ -196,15 +202,18 @@ def load_model(model_path: Path) -> nn.Module:
     return model
 
 
-def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+def load(
+    path_or_hf_repo: str, tokenizer_config={}
+) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model from a given path or a huggingface repository.
 
     Args:
-        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+        model_path (Path): The path or the huggingface repository to load the model from.
+        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+            Defaults to an empty dictionary.
 
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+        nn.Module: The loaded model.
 
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
@@ -213,5 +222,5 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path)
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
 
     return model, tokenizer

View File

@@ -1,45 +0,0 @@
-# Qwen
-
-Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1]
-The architecture of the Qwen models is similar to Llama except for the bias in
-the attention layers.
-
-## Setup
-
-First download and convert the model with:
-
-```sh
-python convert.py
-```
-
-To generate a 4-bit quantized model, use ``-q``. For a full list of options:
-
-```sh
-python convert.py --help
-```
-
-The script downloads the model from Hugging Face. The default model is
-`Qwen/Qwen-1_8B`. Check out the [Hugging Face
-page](https://huggingface.co/Qwen) to see a list of available models.
-
-By default, the conversion script will make the directory `mlx_model` and save
-the converted `weights.npz` and `config.json` there.
-
-## Generate
-
-To generate text with the default prompt:
-
-```sh
-python qwen.py
-```
-
-If you change the model, make sure to pass the corresponding tokenizer. E.g.,
-for Qwen 7B use:
-
-```
-python qwen.py --tokenizer Qwen/Qwen-7B
-```
-
-To see a list of options, run:
-
-```sh
-python qwen.py --help
-```
-
-[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen).

View File

@@ -1,115 +0,0 @@
-import argparse
-import copy
-import json
-from pathlib import Path
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-import torch
-from mlx.utils import tree_flatten, tree_map, tree_unflatten
-from qwen import ModelArgs, Qwen
-from transformers import AutoModelForCausalLM
-
-
-def replace_key(key: str) -> str:
-    if key.startswith("transformer."):
-        # remove transformer prefix
-        key = key.replace("transformer.", "")
-    return key
-
-
-def quantize(weights, config, args):
-    quantized_config = copy.deepcopy(config)
-
-    # Load the model:
-    model_args = ModelArgs()
-    model_args.vocab_size = config["vocab_size"]
-    model_args.hidden_size = config["hidden_size"]
-    model_args.num_attention_heads = config["num_attention_heads"]
-    model_args.num_hidden_layers = config["num_hidden_layers"]
-    model_args.kv_channels = config["kv_channels"]
-    model_args.max_position_embeddings = config["max_position_embeddings"]
-    model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
-    model_args.intermediate_size = config["intermediate_size"]
-    model_args.no_bias = config["no_bias"]
-    model = Qwen(model_args)
-
-    weights = tree_map(mx.array, weights)
-    model.update(tree_unflatten(list(weights.items())))
-
-    # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
-
-    # Update the config:
-    quantized_config["quantization"] = {
-        "group_size": args.q_group_size,
-        "bits": args.q_bits,
-    }
-    quantized_weights = dict(tree_flatten(model.parameters()))
-
-    return quantized_weights, quantized_config
-
-
-def convert(args):
-    mlx_path = Path(args.mlx_path)
-    mlx_path.mkdir(parents=True, exist_ok=True)
-
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model, trust_remote_code=True, torch_dtype=torch.float16
-    )
-    state_dict = model.state_dict()
-    weights = {
-        replace_key(k): (
-            v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()
-        )
-        for k, v in state_dict.items()
-    }
-    config = model.config.to_dict()
-
-    if args.quantize:
-        print("[INFO] Quantizing")
-        weights, config = quantize(weights, config, args)
-
-    np.savez(str(mlx_path / "weights.npz"), **weights)
-
-    # write config
-    with open(mlx_path / "config.json", "w") as f:
-        json.dump(config, f, indent=4)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
-    parser.add_argument(
-        "--model",
-        help="The huggingface model to be converted",
-        default="Qwen/Qwen-1_8B",
-    )
-    parser.add_argument(
-        "--mlx-path",
-        type=str,
-        default="mlx_model",
-        help="The path to save the MLX model.",
-    )
-    parser.add_argument(
-        "-q",
-        "--quantize",
-        help="Generate a quantized model.",
-        action="store_true",
-    )
-    parser.add_argument(
-        "--q-group-size",
-        help="Group size for quantization.",
-        type=int,
-        default=64,
-    )
-    parser.add_argument(
-        "--q-bits",
-        help="Bits per weight for quantization.",
-        type=int,
-        default=4,
-    )
-    args = parser.parse_args()
-    convert(args)
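With this per-model script removed, conversion goes through the shared mlx-lm converter instead (the commit message notes "add qwen model support in convert"). A hedged sketch of the equivalent invocation; the entry point and flag names are assumptions, as the converter itself is not shown in this diff:

```sh
# Sketch only: converts and optionally 4-bit quantizes a Qwen checkpoint.
python -m mlx_lm.convert --hf-path Qwen/Qwen-1_8B --mlx-path mlx_model -q
```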

View File

@@ -1,7 +0,0 @@
-einops
-mlx
-numpy
-transformers>=4.35
-transformers_stream_generator>=0.0.4
-torch
-tiktoken