add deepseek coder example (#172)

* feat: add example for deepseek coder

* chore: remove hardcoded rope_scaling_factor

* feat: add quantization support

* chore: update readme

* chore: clean up the rope scaling factor param in create cos sin theta

* feat: add repetition_penalty

* style/consistency changes to ease future integration

* nits in README

* one more typo

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Anchen 2023-12-29 16:42:22 +11:00 committed by GitHub
parent 37fd2464dc
commit 31ddbd7806
9 changed files with 529 additions and 11 deletions

@@ -0,0 +1,45 @@
# Deepseek Coder
Deepseek Coder is a family of code-generating language models based on the
Llama architecture.[^1] The models were trained from scratch on a corpus of 2T
tokens, with a composition of 87% code and 13% natural language containing both
English and Chinese.
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
Next, download and convert the model.
```sh
python convert.py --hf-path <path_to_huggingface_model>
```
To generate a 4-bit quantized model, use `-q`. For a full list of options, run:
```
python convert.py --help
```
The converter downloads the model from Hugging Face. The default model is
`deepseek-ai/deepseek-coder-6.7b-instruct`. Check out the [Hugging Face
page](https://huggingface.co/deepseek-ai) to see a list of available models.
By default, the conversion script will save the converted `weights.npz`,
tokenizer, and `config.json` in the `mlx_model` directory.
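For example, to convert the default model and quantize it to 4 bits, you could
run the following (the flags are the ones defined in `convert.py`, shown here
with their default values):
```sh
python convert.py \
    --hf-path deepseek-ai/deepseek-coder-6.7b-instruct \
    --mlx-path mlx_model \
    -q --q-bits 4 --q-group-size 64
```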
### Run
Once you've converted the weights, you can interact with the Deepseek Coder
model:
```
python deepseek_coder.py --prompt "write a quick sort algorithm in python."
```
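The generation script also takes a few sampling and length options (see
`python deepseek_coder.py --help` for the full list). For example, to sample up
to 256 tokens at a lower temperature (the prompt here is just an illustration):
```sh
python deepseek_coder.py --max-tokens 256 --temp 0.1 \
    --prompt "write a binary search in python."
```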
[^1]: For more information, see the [blog post](https://deepseekcoder.github.io/)
by DeepSeek AI.

@@ -0,0 +1,154 @@
import argparse
import copy
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from deepseek_coder import DeepseekCoder, ModelArgs
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from transformers import AutoModelForCausalLM, AutoTokenizer
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model_args = ModelArgs(**config)
model = DeepseekCoder(model_args)
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def convert(args):
hf_path = Path(args.hf_path)
model = AutoModelForCausalLM.from_pretrained(
str(hf_path), trust_remote_code=True, torch_dtype=torch.float16
)
config = model.config.to_dict()
state_dict = model.state_dict()
tokenizer = AutoTokenizer.from_pretrained(str(hf_path), trust_remote_code=True)
# things to change
# 1. there's no "model." in the weight names
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
# 2. mlp is called feed_forward
state_dict = {k.replace("mlp", "feed_forward"): v for k, v in state_dict.items()}
# 3. up_proj, down_proj, gate_proj
state_dict = {k.replace("down_proj", "w2"): v for k, v in state_dict.items()}
state_dict = {k.replace("up_proj", "w3"): v for k, v in state_dict.items()}
state_dict = {k.replace("gate_proj", "w1"): v for k, v in state_dict.items()}
# 4. layernorms
state_dict = {
k.replace("input_layernorm", "attention_norm"): v for k, v in state_dict.items()
}
state_dict = {
k.replace("post_attention_layernorm", "ffn_norm"): v
for k, v in state_dict.items()
}
# 5. lm head
state_dict = {k.replace("lm_head", "output"): v for k, v in state_dict.items()}
# 6. token emb
state_dict = {
k.replace("embed_tokens", "tok_embeddings"): v for k, v in state_dict.items()
}
# 7. attention
state_dict = {k.replace("self_attn", "attention"): v for k, v in state_dict.items()}
state_dict = {k.replace("q_proj", "wq"): v for k, v in state_dict.items()}
state_dict = {k.replace("k_proj", "wk"): v for k, v in state_dict.items()}
state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
weights = {k: v.numpy() for k, v in state_dict.items()}
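# the HF config stores RoPE scaling as a dict; keep only the linear scaling factor for ModelArgs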
config["rope_scaling_factor"] = config["rope_scaling"]["factor"]
keep_keys = set(
[
"vocab_size",
"hidden_size",
"num_attention_heads",
"num_key_value_heads",
"num_hidden_layers",
"max_position_embeddings",
"rms_norm_eps",
"intermediate_size",
"rope_scaling_factor",
]
)
for k in list(config.keys()):
if k not in keep_keys:
config.pop(k)
return weights, config, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Deepseek coder model to npz")
parser.add_argument(
"--hf-path",
help="The huggingface model to be converted",
default="deepseek-ai/deepseek-coder-6.7b-instruct",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
weights, config, tokenizer = convert(args)
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "deepseek_coder"
json.dump(config, f, indent=4)

@@ -0,0 +1,309 @@
import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
hidden_size: int = 4096
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: int = 32
max_position_embeddings: int = 16384
rms_norm_eps: float = 1e-6
intermediate_size: int = 11008
rope_theta: float = 100000
rope_scaling_factor: float = 4.0
vocab_size: int = 32256
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class LinearScalingRoPE(nn.RoPE):
def __init__(
self, dims: int, rope_scaling_factor: float = 4.0, base: float = 10000
):
super().__init__(dims)
self.base = base
self.rope_scaling_factor = rope_scaling_factor
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
N,
self.dims,
offset=offset,
base=self.base,
rope_scaling_factor=self.rope_scaling_factor,
dtype=x.dtype,
)
rx = self._compute_rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
rope_scaling_factor: float = 1.0,
dtype=mx.float32,
):
D = D // 2
positions = mx.arange(offset, N, dtype=dtype)
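# linear scaling: compress the position indices by the scaling factor so longer contexts map into the trained RoPE range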
positions = positions / rope_scaling_factor
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
return mx.cos(theta), mx.sin(theta)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads: int = args.num_attention_heads
self.num_key_value_heads: int = args.num_key_value_heads
self.repeats = self.num_attention_heads // self.num_key_value_heads
self.head_dim = args.hidden_size // args.num_attention_heads
self.scale = self.head_dim**-0.5
self.wq = nn.Linear(
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
)
self.wk = nn.Linear(
args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
)
self.wv = nn.Linear(
args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
)
self.wo = nn.Linear(
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
)
self.rope = LinearScalingRoPE(
self.head_dim,
rope_scaling_factor=args.rope_scaling_factor,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
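# repeat each key/value head so the grouped KV heads match the number of query heads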
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.num_attention_heads, L, -1])
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class DeepseekCoder(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, x, mask=None, cache=None):
x = self.tok_embeddings(x)
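# a causal mask is only needed when processing more than one token (the prompt pass)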
mask = None
T = x.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
x = self.norm(x)
return self.output(x), cache
def generate(
prompt: mx.array,
model: DeepseekCoder,
temp: float = 0.0,
):
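# temp == 0 means greedy (argmax) decoding; otherwise sample from the temperature-scaled logits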
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
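# the first call processes the full prompt; later calls feed back one token and reuse the per-layer key/value cache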
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
def load_model(model_path: str):
model_path = Path(model_path)
with open(model_path / "config.json", "r") as f:
config = json.load(f)
config.pop("model_type")
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
model = DeepseekCoder(model_args)
weights = mx.load(str(model_path / "weights.npz"))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Deepseek coder inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the mlx model weights, tokenizer, and config",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="### Instruction: \nwrite a quick sort algorithm in python.\n### Response: \n",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.6,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model(args.model_path)
prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
"input_ids"
][0]
prompt = mx.array(prompt)
print(args.prompt, end="", flush=True)
tokens = []
skip = 0
for token, _ in zip(
generate(prompt, model, args.temp),
range(args.max_tokens),
):
if token == tokenizer.eos_token_id:
break
tokens.append(token.item())
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)

@@ -0,0 +1,4 @@
torch
mlx
numpy
transformers>=4.35

@@ -15,6 +15,7 @@ import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def llama(model_path):
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
@@ -185,13 +186,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,

@@ -57,13 +57,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,

@@ -110,13 +110,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,

@@ -56,13 +56,13 @@ def convert():
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,

@@ -60,7 +60,12 @@ def convert(args):
args.model, trust_remote_code=True, torch_dtype=torch.float16
)
state_dict = model.state_dict()
weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()}
weights = {
replace_key(k): (
v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()
)
for k, v in state_dict.items()
}
config = model.config.to_dict()
if args.quantize:
@@ -95,13 +100,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,