mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
refactor(qwen): moving qwen into mlx-lm (#312)
* refactor(qwen): moving qwen into mlx-lm * chore: update doc * chore: fix type hint * add qwen model support in convert * chore: fix doc * chore: only load model in quantize_model * chore: make the convert script only copy tokenizer files instead of load it and save * chore: update docstring * chore: remove unnecessary try catch * chore: clean up for tokenizer and update transformers 4.37 * nits in README --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
de15532da8
commit
30be4c4734
@ -102,11 +102,26 @@ Here are a few examples of Hugging Face models that work with this example:
|
|||||||
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
|
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
|
||||||
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
|
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
|
||||||
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||||
|
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
|
||||||
|
|
||||||
Most
|
Most
|
||||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||||
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
|
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
|
||||||
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
|
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
|
||||||
and
|
and
|
||||||
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
|
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
|
||||||
style models should work out of the box.
|
style models should work out of the box.
|
||||||
|
|
||||||
|
For
|
||||||
|
[Qwen](https://huggingface.co/models?library=transformers,safetensors&other=qwen&sort=trending)
|
||||||
|
style models, you must enable the `trust_remote_code` option and specify the
|
||||||
|
`eos_token`. This ensures the tokenizer works correctly. You can do this by
|
||||||
|
passing `--trust-remote-code` and `--eos-token "<|endoftext|>"` in the command
|
||||||
|
line, or by setting these options in the Python API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model, tokenizer = load(
|
||||||
|
"qwen/Qwen-7B",
|
||||||
|
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
@ -21,6 +21,17 @@ def setup_arg_parser():
|
|||||||
default="mlx_model",
|
default="mlx_model",
|
||||||
help="The path to the local model directory or Hugging Face repo.",
|
help="The path to the local model directory or Hugging Face repo.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--trust-remote-code",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable trusting remote code for tokenizer",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eos-token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="End of sequence token for tokenizer",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
|
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
|
||||||
)
|
)
|
||||||
@ -40,7 +51,13 @@ def setup_arg_parser():
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
model, tokenizer = load(args.model)
|
|
||||||
|
# Building tokenizer_config
|
||||||
|
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
|
||||||
|
if args.eos_token is not None:
|
||||||
|
tokenizer_config["eos_token"] = args.eos_token
|
||||||
|
|
||||||
|
model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print("Prompt:", args.prompt)
|
print("Prompt:", args.prompt)
|
||||||
prompt = tokenizer.encode(args.prompt)
|
prompt = tokenizer.encode(args.prompt)
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
import argparse
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from typing import Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_unflatten
|
|
||||||
from transformers import AutoTokenizer
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs:
|
class ModelArgs(BaseModelArgs):
|
||||||
hidden_size: int = 2048
|
hidden_size: int = 2048
|
||||||
num_attention_heads: int = 16
|
num_attention_heads: int = 16
|
||||||
num_hidden_layers: int = 24
|
num_hidden_layers: int = 24
|
||||||
@ -20,6 +18,11 @@ class ModelArgs:
|
|||||||
intermediate_size: int = 11008
|
intermediate_size: int = 11008
|
||||||
no_bias: bool = True
|
no_bias: bool = True
|
||||||
vocab_size: int = 151936
|
vocab_size: int = 151936
|
||||||
|
num_key_value_heads = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.num_key_value_heads is None:
|
||||||
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
@ -95,7 +98,7 @@ class MLP(nn.Module):
|
|||||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||||
)
|
)
|
||||||
self.w2 = nn.Linear(
|
self.w2 = nn.Linear(
|
||||||
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||||
)
|
)
|
||||||
self.c_proj = nn.Linear(
|
self.c_proj = nn.Linear(
|
||||||
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
||||||
@ -128,17 +131,12 @@ class TransformerBlock(nn.Module):
|
|||||||
return x, cache
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
class Qwen(nn.Module):
|
class QwenModel(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_dim = args.hidden_size
|
|
||||||
|
|
||||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||||
self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
|
self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
|
|
||||||
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(self, inputs, mask=None, cache=None):
|
def __call__(self, inputs, mask=None, cache=None):
|
||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
@ -156,123 +154,22 @@ class Qwen(nn.Module):
|
|||||||
x, cache[e] = layer(x, mask, cache[e])
|
x, cache[e] = layer(x, mask, cache[e])
|
||||||
|
|
||||||
x = self.ln_f(x[:, T - 1 : T, :])
|
x = self.ln_f(x[:, T - 1 : T, :])
|
||||||
return self.lm_head(x), cache
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
def generate(prompt: mx.array, model: Qwen, temp: 0.0):
|
class Model(nn.Module):
|
||||||
def sample(logits):
|
def __init__(self, config: ModelArgs):
|
||||||
if temp == 0:
|
super().__init__()
|
||||||
return mx.argmax(logits, axis=-1)
|
self.transformer = QwenModel(config)
|
||||||
else:
|
self.lm_head = nn.Linear(
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
config.hidden_size, config.vocab_size, bias=not config.no_bias
|
||||||
|
|
||||||
logits, cache = model(prompt)
|
|
||||||
y = sample(logits[:, -1, :])
|
|
||||||
yield y
|
|
||||||
|
|
||||||
while True:
|
|
||||||
logits, cache = model(y[:, None], cache=cache)
|
|
||||||
y = sample(logits.squeeze(1))
|
|
||||||
yield y
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
|
|
||||||
model_args = ModelArgs()
|
|
||||||
|
|
||||||
model_path = Path(model_path)
|
|
||||||
with open(model_path / "config.json", "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
model_args.vocab_size = config["vocab_size"]
|
|
||||||
model_args.hidden_size = config["hidden_size"]
|
|
||||||
model_args.num_attention_heads = config["num_attention_heads"]
|
|
||||||
model_args.num_hidden_layers = config["num_hidden_layers"]
|
|
||||||
model_args.kv_channels = config["kv_channels"]
|
|
||||||
model_args.max_position_embeddings = config["max_position_embeddings"]
|
|
||||||
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
|
|
||||||
model_args.intermediate_size = config["intermediate_size"]
|
|
||||||
model_args.no_bias = config["no_bias"]
|
|
||||||
|
|
||||||
model = Qwen(model_args)
|
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
|
||||||
if quantization := config.get("quantization", False):
|
|
||||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
|
||||||
model.update(tree_unflatten(list(weights.items())))
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
|
|
||||||
)
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Qwen inference script")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model-path",
|
|
||||||
type=str,
|
|
||||||
default="mlx_model",
|
|
||||||
help="The path to the model weights and config",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokenizer",
|
|
||||||
help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
|
|
||||||
default="Qwen/Qwen-1_8B",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
|
||||||
help="The message to be processed by the model",
|
|
||||||
# The example from the official huggingface repo of Qwen
|
|
||||||
default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-tokens",
|
|
||||||
"-m",
|
|
||||||
type=int,
|
|
||||||
default=100,
|
|
||||||
help="Maximum number of tokens to generate",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--temp",
|
|
||||||
help="The sampling temperature.",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
)
|
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
mx.random.seed(args.seed)
|
|
||||||
|
|
||||||
model, tokenizer = load_model(args.model_path, args.tokenizer)
|
|
||||||
|
|
||||||
prompt = tokenizer(
|
|
||||||
args.prompt,
|
|
||||||
return_tensors="np",
|
|
||||||
return_attention_mask=False,
|
|
||||||
)["input_ids"]
|
|
||||||
|
|
||||||
prompt = mx.array(prompt)
|
|
||||||
|
|
||||||
print(args.prompt, end="", flush=True)
|
|
||||||
|
|
||||||
tokens = []
|
|
||||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
|
||||||
tokens.append(token)
|
|
||||||
|
|
||||||
if (len(tokens) % 10) == 0:
|
|
||||||
mx.eval(tokens)
|
|
||||||
eos_index = next(
|
|
||||||
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
|
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if eos_index is not None:
|
def __call__(
|
||||||
tokens = tokens[:eos_index]
|
self,
|
||||||
|
x: mx.array,
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
mask: mx.array = None,
|
||||||
print(s, end="", flush=True)
|
cache: mx.array = None,
|
||||||
tokens = []
|
) -> Tuple[mx.array, mx.array]:
|
||||||
if eos_index is not None:
|
y, cache = self.transformer(x, mask, cache)
|
||||||
break
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
mx.eval(tokens)
|
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
|
||||||
print(s, flush=True)
|
|
@ -1,4 +1,4 @@
|
|||||||
mlx
|
mlx
|
||||||
numpy
|
numpy
|
||||||
transformers
|
transformers>=4.37.0
|
||||||
protobuf
|
protobuf
|
||||||
|
@ -10,8 +10,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models import llama, mixtral, phi2
|
from .models import llama, mixtral, phi2, qwen
|
||||||
from .models.base import BaseModelArgs
|
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MODEL_MAPPING = {
|
MODEL_MAPPING = {
|
||||||
@ -19,6 +18,7 @@ MODEL_MAPPING = {
|
|||||||
"mistral": llama, # mistral is compatible with llama
|
"mistral": llama, # mistral is compatible with llama
|
||||||
"mixtral": mixtral,
|
"mixtral": mixtral,
|
||||||
"phi": phi2,
|
"phi": phi2,
|
||||||
|
"qwen": qwen,
|
||||||
}
|
}
|
||||||
|
|
||||||
linear_class_predicate = (
|
linear_class_predicate = (
|
||||||
@ -64,7 +64,13 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
|||||||
model_path = Path(
|
model_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=path_or_hf_repo,
|
repo_id=path_or_hf_repo,
|
||||||
allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
|
allow_patterns=[
|
||||||
|
"*.json",
|
||||||
|
"*.safetensors",
|
||||||
|
"*.py",
|
||||||
|
"tokenizer.model",
|
||||||
|
"*.tiktoken",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return model_path
|
return model_path
|
||||||
@ -196,15 +202,18 @@ def load_model(model_path: Path) -> nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
def load(
|
||||||
|
path_or_hf_repo: str, tokenizer_config={}
|
||||||
|
) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
||||||
"""
|
"""
|
||||||
Load the model from a given path or a huggingface repository.
|
Load the model from a given path or a huggingface repository.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path_or_hf_repo (str): The path or the huggingface repository to load the model from.
|
model_path (Path): The path or the huggingface repository to load the model from.
|
||||||
|
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
|
||||||
|
Defaults to an empty dictionary.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
|
nn.Module: The loaded model.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If config file or safetensors are not found.
|
FileNotFoundError: If config file or safetensors are not found.
|
||||||
@ -213,5 +222,5 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
|
|||||||
model_path = get_model_path(path_or_hf_repo)
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
model = load_model(model_path)
|
model = load_model(model_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
# Qwen
|
|
||||||
|
|
||||||
Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1]
|
|
||||||
The architecture of the Qwen models is similar to Llama except for the bias in
|
|
||||||
the attention layers.
|
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
First download and convert the model with:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python convert.py
|
|
||||||
```
|
|
||||||
|
|
||||||
To generate a 4-bit quantized model, use ``-q``. For a full list of options:
|
|
||||||
|
|
||||||
The script downloads the model from Hugging Face. The default model is
|
|
||||||
`Qwen/Qwen-1_8B`. Check out the [Hugging Face
|
|
||||||
page](https://huggingface.co/Qwen) to see a list of available models.
|
|
||||||
|
|
||||||
By default, the conversion script will make the directory `mlx_model` and save
|
|
||||||
the converted `weights.npz` and `config.json` there.
|
|
||||||
|
|
||||||
## Generate
|
|
||||||
|
|
||||||
To generate text with the default prompt:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python qwen.py
|
|
||||||
```
|
|
||||||
|
|
||||||
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
|
|
||||||
for Qwen 7B use:
|
|
||||||
|
|
||||||
```
|
|
||||||
python qwen.py --tokenizer Qwen/Qwen-7B
|
|
||||||
```
|
|
||||||
|
|
||||||
To see a list of options, run:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python qwen.py --help
|
|
||||||
```
|
|
||||||
|
|
||||||
[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen).
|
|
@ -1,115 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
|
||||||
from qwen import ModelArgs, Qwen
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
|
||||||
if key.startswith("transformer."):
|
|
||||||
# remove transformer prefix
|
|
||||||
key = key.replace("transformer.", "")
|
|
||||||
|
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
def quantize(weights, config, args):
|
|
||||||
quantized_config = copy.deepcopy(config)
|
|
||||||
|
|
||||||
# Load the model:
|
|
||||||
model_args = ModelArgs()
|
|
||||||
model_args.vocab_size = config["vocab_size"]
|
|
||||||
model_args.hidden_size = config["hidden_size"]
|
|
||||||
model_args.num_attention_heads = config["num_attention_heads"]
|
|
||||||
model_args.num_hidden_layers = config["num_hidden_layers"]
|
|
||||||
model_args.kv_channels = config["kv_channels"]
|
|
||||||
model_args.max_position_embeddings = config["max_position_embeddings"]
|
|
||||||
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
|
|
||||||
model_args.intermediate_size = config["intermediate_size"]
|
|
||||||
model_args.no_bias = config["no_bias"]
|
|
||||||
model = Qwen(model_args)
|
|
||||||
|
|
||||||
weights = tree_map(mx.array, weights)
|
|
||||||
model.update(tree_unflatten(list(weights.items())))
|
|
||||||
|
|
||||||
# Quantize the model:
|
|
||||||
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
|
|
||||||
|
|
||||||
# Update the config:
|
|
||||||
quantized_config["quantization"] = {
|
|
||||||
"group_size": args.q_group_size,
|
|
||||||
"bits": args.q_bits,
|
|
||||||
}
|
|
||||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
|
||||||
|
|
||||||
return quantized_weights, quantized_config
|
|
||||||
|
|
||||||
|
|
||||||
def convert(args):
|
|
||||||
mlx_path = Path(args.mlx_path)
|
|
||||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
args.model, trust_remote_code=True, torch_dtype=torch.float16
|
|
||||||
)
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
weights = {
|
|
||||||
replace_key(k): (
|
|
||||||
v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()
|
|
||||||
)
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
}
|
|
||||||
config = model.config.to_dict()
|
|
||||||
|
|
||||||
if args.quantize:
|
|
||||||
print("[INFO] Quantizing")
|
|
||||||
weights, config = quantize(weights, config, args)
|
|
||||||
|
|
||||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
|
||||||
|
|
||||||
# write config
|
|
||||||
with open(mlx_path / "config.json", "w") as f:
|
|
||||||
json.dump(config, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
help="The huggingface model to be converted",
|
|
||||||
default="Qwen/Qwen-1_8B",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mlx-path",
|
|
||||||
type=str,
|
|
||||||
default="mlx_model",
|
|
||||||
help="The path to save the MLX model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-q",
|
|
||||||
"--quantize",
|
|
||||||
help="Generate a quantized model.",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--q-group-size",
|
|
||||||
help="Group size for quantization.",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--q-bits",
|
|
||||||
help="Bits per weight for quantization.",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
convert(args)
|
|
@ -1,7 +0,0 @@
|
|||||||
einops
|
|
||||||
mlx
|
|
||||||
numpy
|
|
||||||
transformers>=4.35
|
|
||||||
transformers_stream_generator>=0.0.4
|
|
||||||
torch
|
|
||||||
tiktoken
|
|
Loading…
Reference in New Issue
Block a user