From 30be4c473408800e9e343e2e0df805a27e1ab156 Mon Sep 17 00:00:00 2001 From: Anchen Date: Mon, 22 Jan 2024 15:00:07 -0800 Subject: [PATCH] refactor(qwen): moving qwen into mlx-lm (#312) * refactor(qwen): moving qwen into mlx-lm * chore: update doc * chore: fix type hint * add qwen model support in convert * chore: fix doc * chore: only load model in quantize_model * chore: make the convert script only copy tokenizer files instead of load it and save * chore: update docstring * chore: remove unnecessary try catch * chore: clean up for tokenizer and update transformers 4.37 * nits in README --------- Co-authored-by: Awni Hannun --- llms/README.md | 17 ++- llms/mlx_lm/generate.py | 19 +++- llms/{qwen => mlx_lm/models}/qwen.py | 159 +++++---------------------- llms/mlx_lm/requirements.txt | 2 +- llms/mlx_lm/utils.py | 25 +++-- llms/qwen/README.md | 45 -------- llms/qwen/convert.py | 115 ------------------- llms/qwen/requirements.txt | 7 -- 8 files changed, 80 insertions(+), 309 deletions(-) rename llms/{qwen => mlx_lm/models}/qwen.py (50%) delete mode 100644 llms/qwen/README.md delete mode 100644 llms/qwen/convert.py delete mode 100644 llms/qwen/requirements.txt diff --git a/llms/README.md b/llms/README.md index 40dbc067..ad4940a8 100644 --- a/llms/README.md +++ b/llms/README.md @@ -102,11 +102,26 @@ Here are a few examples of Hugging Face models that work with this example: - [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat) - [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) - [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) +- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B) Most [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), [Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending), -[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending) 
+[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending), and [Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending) style models should work out of the box. + +For +[Qwen](https://huggingface.co/models?library=transformers,safetensors&other=qwen&sort=trending) +style models, you must enable the `trust_remote_code` option and specify the +`eos_token`. This ensures the tokenizer works correctly. You can do this by +passing `--trust-remote-code` and `--eos-token "<|endoftext|>"` in the command +line, or by setting these options in the Python API: + +```python +model, tokenizer = load( + "qwen/Qwen-7B", + tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True}, +) +``` diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 237fb056..530a3483 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -21,6 +21,17 @@ def setup_arg_parser(): default="mlx_model", help="The path to the local model directory or Hugging Face repo.", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Enable trusting remote code for tokenizer", + ) + parser.add_argument( + "--eos-token", + type=str, + default=None, + help="End of sequence token for tokenizer", + ) parser.add_argument( "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" ) @@ -40,7 +51,13 @@ def setup_arg_parser(): def main(args): mx.random.seed(args.seed) - model, tokenizer = load(args.model) + + # Building tokenizer_config + tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} + if args.eos_token is not None: + tokenizer_config["eos_token"] = args.eos_token + + model, tokenizer = load(args.model, tokenizer_config=tokenizer_config) print("=" * 10) print("Prompt:", args.prompt) prompt = tokenizer.encode(args.prompt) diff --git a/llms/qwen/qwen.py b/llms/mlx_lm/models/qwen.py similarity index 50% rename from 
llms/qwen/qwen.py rename to llms/mlx_lm/models/qwen.py index 532a8031..a086a95c 100644 --- a/llms/qwen/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -1,16 +1,14 @@ -import argparse -import json from dataclasses import dataclass -from pathlib import Path +from typing import Tuple import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_unflatten -from transformers import AutoTokenizer + +from .base import BaseModelArgs @dataclass -class ModelArgs: +class ModelArgs(BaseModelArgs): hidden_size: int = 2048 num_attention_heads: int = 16 num_hidden_layers: int = 24 @@ -20,6 +18,11 @@ class ModelArgs: intermediate_size: int = 11008 no_bias: bool = True vocab_size: int = 151936 + num_key_value_heads = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads class RMSNorm(nn.Module): @@ -95,7 +98,7 @@ class MLP(nn.Module): args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias ) self.w2 = nn.Linear( - args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias + args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias ) self.c_proj = nn.Linear( args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias @@ -128,17 +131,12 @@ class TransformerBlock(nn.Module): return x, cache -class Qwen(nn.Module): +class QwenModel(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - - self.embed_dim = args.hidden_size - self.wte = nn.Embedding(args.vocab_size, args.hidden_size) self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] - self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon) - - self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False) + self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) def __call__(self, inputs, mask=None, cache=None): x = self.wte(inputs) @@ -156,123 +154,22 @@ class Qwen(nn.Module): x, cache[e] = layer(x, mask, cache[e]) x = self.ln_f(x[:, T - 1 : T, :]) - 
return self.lm_head(x), cache + return x, cache -def generate(prompt: mx.array, model: Qwen, temp: 0.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.transformer = QwenModel(config) + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=not config.no_bias + ) - logits, cache = model(prompt) - y = sample(logits[:, -1, :]) - yield y - - while True: - logits, cache = model(y[:, None], cache=cache) - y = sample(logits.squeeze(1)) - yield y - - -def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"): - model_args = ModelArgs() - - model_path = Path(model_path) - with open(model_path / "config.json", "r") as f: - config = json.load(f) - model_args.vocab_size = config["vocab_size"] - model_args.hidden_size = config["hidden_size"] - model_args.num_attention_heads = config["num_attention_heads"] - model_args.num_hidden_layers = config["num_hidden_layers"] - model_args.kv_channels = config["kv_channels"] - model_args.max_position_embeddings = config["max_position_embeddings"] - model_args.layer_norm_epsilon = config["layer_norm_epsilon"] - model_args.intermediate_size = config["intermediate_size"] - model_args.no_bias = config["no_bias"] - - model = Qwen(model_args) - weights = mx.load(str(model_path / "weights.npz")) - if quantization := config.get("quantization", False): - nn.QuantizedLinear.quantize_module(model, **quantization) - model.update(tree_unflatten(list(weights.items()))) - - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>" - ) - return model, tokenizer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Qwen inference script") - parser.add_argument( - "--model-path", - type=str, - default="mlx_model", - help="The path to the model weights and config", - ) - 
parser.add_argument( - "--tokenizer", - help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B", - default="Qwen/Qwen-1_8B", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model", - # The example from the official huggingface repo of Qwen - default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", - help="The sampling temperature.", - type=float, - default=0.0, - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - args = parser.parse_args() - - mx.random.seed(args.seed) - - model, tokenizer = load_model(args.model_path, args.tokenizer) - - prompt = tokenizer( - args.prompt, - return_tensors="np", - return_attention_mask=False, - )["input_ids"] - - prompt = mx.array(prompt) - - print(args.prompt, end="", flush=True) - - tokens = [] - for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): - tokens.append(token) - - if (len(tokens) % 10) == 0: - mx.eval(tokens) - eos_index = next( - (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), - None, - ) - - if eos_index is not None: - tokens = tokens[:eos_index] - - s = tokenizer.decode([t.item() for t in tokens]) - print(s, end="", flush=True) - tokens = [] - if eos_index is not None: - break - - mx.eval(tokens) - s = tokenizer.decode([t.item() for t in tokens]) - print(s, flush=True) + def __call__( + self, + x: mx.array, + mask: mx.array = None, + cache: mx.array = None, + ) -> Tuple[mx.array, mx.array]: + y, cache = self.transformer(x, mask, cache) + return self.lm_head(y), cache diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index c78cefa2..a04cc7bb 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ mlx numpy -transformers +transformers>=4.37.0 protobuf diff --git 
a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index a7eaea52..5e8f8e2a 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -10,8 +10,7 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer, PreTrainedTokenizer # Local imports -from .models import llama, mixtral, phi2 -from .models.base import BaseModelArgs +from .models import llama, mixtral, phi2, qwen # Constants MODEL_MAPPING = { @@ -19,6 +18,7 @@ MODEL_MAPPING = { "mistral": llama, # mistral is compatible with llama "mixtral": mixtral, "phi": phi2, + "qwen": qwen, } linear_class_predicate = ( @@ -64,7 +64,13 @@ def get_model_path(path_or_hf_repo: str) -> Path: model_path = Path( snapshot_download( repo_id=path_or_hf_repo, - allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"], + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], ) ) return model_path @@ -196,15 +202,18 @@ def load_model(model_path: Path) -> nn.Module: return model -def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]: +def load( + path_or_hf_repo: str, tokenizer_config={} +) -> Tuple[nn.Module, PreTrainedTokenizer]: """ Load the model from a given path or a huggingface repository. Args: - path_or_hf_repo (str): The path or the huggingface repository to load the model from. - + path_or_hf_repo (str): The path or the huggingface repository to load the model from. + tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. + Defaults to an empty dictionary. Returns: - Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer. + Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer. Raises: FileNotFoundError: If config file or safetensors are not found. 
@@ -213,5 +222,5 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]: model_path = get_model_path(path_or_hf_repo) model = load_model(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config) return model, tokenizer diff --git a/llms/qwen/README.md b/llms/qwen/README.md deleted file mode 100644 index 75154325..00000000 --- a/llms/qwen/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# Qwen - -Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1] -The architecture of the Qwen models is similar to Llama except for the bias in -the attention layers. - -## Setup - -First download and convert the model with: - -```sh -python convert.py -``` - -To generate a 4-bit quantized model, use ``-q``. For a full list of options: - -The script downloads the model from Hugging Face. The default model is -`Qwen/Qwen-1_8B`. Check out the [Hugging Face -page](https://huggingface.co/Qwen) to see a list of available models. - -By default, the conversion script will make the directory `mlx_model` and save -the converted `weights.npz` and `config.json` there. - -## Generate - -To generate text with the default prompt: - -```sh -python qwen.py -``` - -If you change the model, make sure to pass the corresponding tokenizer. E.g., -for Qwen 7B use: - -``` -python qwen.py --tokenizer Qwen/Qwen-7B -``` - -To see a list of options, run: - -```sh -python qwen.py --help -``` - -[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen). 
diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py deleted file mode 100644 index 9f7bed89..00000000 --- a/llms/qwen/convert.py +++ /dev/null @@ -1,115 +0,0 @@ -import argparse -import copy -import json -from pathlib import Path - -import mlx.core as mx -import mlx.nn as nn -import numpy as np -import torch -from mlx.utils import tree_flatten, tree_map, tree_unflatten -from qwen import ModelArgs, Qwen -from transformers import AutoModelForCausalLM - - -def replace_key(key: str) -> str: - if key.startswith("transformer."): - # remove transformer prefix - key = key.replace("transformer.", "") - - return key - - -def quantize(weights, config, args): - quantized_config = copy.deepcopy(config) - - # Load the model: - model_args = ModelArgs() - model_args.vocab_size = config["vocab_size"] - model_args.hidden_size = config["hidden_size"] - model_args.num_attention_heads = config["num_attention_heads"] - model_args.num_hidden_layers = config["num_hidden_layers"] - model_args.kv_channels = config["kv_channels"] - model_args.max_position_embeddings = config["max_position_embeddings"] - model_args.layer_norm_epsilon = config["layer_norm_epsilon"] - model_args.intermediate_size = config["intermediate_size"] - model_args.no_bias = config["no_bias"] - model = Qwen(model_args) - - weights = tree_map(mx.array, weights) - model.update(tree_unflatten(list(weights.items()))) - - # Quantize the model: - nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) - - # Update the config: - quantized_config["quantization"] = { - "group_size": args.q_group_size, - "bits": args.q_bits, - } - quantized_weights = dict(tree_flatten(model.parameters())) - - return quantized_weights, quantized_config - - -def convert(args): - mlx_path = Path(args.mlx_path) - mlx_path.mkdir(parents=True, exist_ok=True) - - model = AutoModelForCausalLM.from_pretrained( - args.model, trust_remote_code=True, torch_dtype=torch.float16 - ) - state_dict = model.state_dict() - weights = { - 
replace_key(k): ( - v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy() - ) - for k, v in state_dict.items() - } - config = model.config.to_dict() - - if args.quantize: - print("[INFO] Quantizing") - weights, config = quantize(weights, config, args) - - np.savez(str(mlx_path / "weights.npz"), **weights) - - # write config - with open(mlx_path / "config.json", "w") as f: - json.dump(config, f, indent=4) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert Qwen model to npz") - - parser.add_argument( - "--model", - help="The huggingface model to be converted", - default="Qwen/Qwen-1_8B", - ) - parser.add_argument( - "--mlx-path", - type=str, - default="mlx_model", - help="The path to save the MLX model.", - ) - parser.add_argument( - "-q", - "--quantize", - help="Generate a quantized model.", - action="store_true", - ) - parser.add_argument( - "--q-group-size", - help="Group size for quantization.", - type=int, - default=64, - ) - parser.add_argument( - "--q-bits", - help="Bits per weight for quantization.", - type=int, - default=4, - ) - args = parser.parse_args() - convert(args) diff --git a/llms/qwen/requirements.txt b/llms/qwen/requirements.txt deleted file mode 100644 index 0ce17aec..00000000 --- a/llms/qwen/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -einops -mlx -numpy -transformers>=4.35 -transformers_stream_generator>=0.0.4 -torch -tiktoken