feat: add example for deepseek coder

Anchen 2023-12-22 17:51:06 +11:00 committed by Awni Hannun
parent 50fceb1a28
commit e17e07002a
4 changed files with 442 additions and 0 deletions

README.md
@@ -0,0 +1,30 @@
# Deepseek Coder

Deepseek Coder is an advanced series of code language models based on the Llama architecture, trained from scratch on a corpus of 2T tokens composed of 87% code and 13% natural language in both English and Chinese.
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
Next, download and convert the model.
```sh
python convert.py --model-path <path_to_huggingface_model> --mlx-path <path_to_save_converted_model>
```
By default, the conversion script saves the converted `weights.npz`, the tokenizer files, and `config.json` to the path you specified with `--mlx-path`.
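If you want to sanity-check the conversion before running inference, the saved arrays can be inspected directly. This is a minimal sketch, assuming the converter's default `--mlx-path` of `mlx_model`:

```python
# Quick check of the converted weights (assumes the default --mlx-path of mlx_model).
import mlx.core as mx

weights = mx.load("mlx_model/weights.npz")  # dict of parameter name -> array
print(len(weights), "arrays")
print(sorted(weights)[:3])  # a few parameter names, e.g. attention/feed_forward weights
```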
### Run
Once you've converted the weights to MLX format, you can interact with the
Deepseek Coder model:
```
python deepseek-coder.py --model-path <path_to_save_converted_model> --prompt "write a quick sort algorithm in python."
```
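The script also accepts a few sampling options defined in the argument parser of `deepseek-coder.py`: `--max-tokens` (default 500), `--temp` (default 0.6), and `--seed` (default 0). For example:

```sh
python deepseek-coder.py --model-path <path_to_save_converted_model> \
  --prompt "write a quick sort algorithm in python." \
  --max-tokens 256 --temp 0.0
```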

convert.py
@@ -0,0 +1,84 @@
import argparse
from pathlib import Path
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def convert(args):
    model_path = Path(args.model_path)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    model = AutoModelForCausalLM.from_pretrained(
        str(model_path), trust_remote_code=True, torch_dtype=torch.float16
    )
    config = model.config.to_dict()
    state_dict = model.state_dict()
    tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)

    # things to change
    # 1. there's no "model." in the weight names
    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}

    # 2. mlp is called feed_forward
    state_dict = {k.replace("mlp", "feed_forward"): v for k, v in state_dict.items()}

    # 3. up_proj, down_proj, gate_proj
    state_dict = {k.replace("down_proj", "w2"): v for k, v in state_dict.items()}
    state_dict = {k.replace("up_proj", "w3"): v for k, v in state_dict.items()}
    state_dict = {k.replace("gate_proj", "w1"): v for k, v in state_dict.items()}

    # 4. layernorms
    state_dict = {
        k.replace("input_layernorm", "attention_norm"): v for k, v in state_dict.items()
    }
    state_dict = {
        k.replace("post_attention_layernorm", "ffn_norm"): v
        for k, v in state_dict.items()
    }

    # 5. lm head
    state_dict = {k.replace("lm_head", "output"): v for k, v in state_dict.items()}

    # 6. token emb
    state_dict = {
        k.replace("embed_tokens", "tok_embeddings"): v for k, v in state_dict.items()
    }

    # 7. attention
    state_dict = {k.replace("self_attn", "attention"): v for k, v in state_dict.items()}
    state_dict = {k.replace("q_proj", "wq"): v for k, v in state_dict.items()}
    state_dict = {k.replace("k_proj", "wk"): v for k, v in state_dict.items()}
    state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
    state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}

    weights = {k: v.numpy() for k, v in state_dict.items()}
    np.savez(str(mlx_path / "weights.npz"), **weights)
    tokenizer.save_pretrained(mlx_path)
    with open(mlx_path / "config.json", "w") as f:
        json.dump(config, f, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Deepseek coder model to npz")
    parser.add_argument(
        "--model-path",
        help="The huggingface model to be converted",
        default="deepseek-ai/deepseek-coder-6.7b-instruct",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="The path to save the MLX model.",
    )
    args = parser.parse_args()
    convert(args)

deepseek-coder.py
@@ -0,0 +1,324 @@
import argparse
import math
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_hidden_layers: int = 32
    num_key_value_heads: int = 32
    max_position_embeddings: int = 16384
    layer_norm_epsilon: float = 1e-6
    intermediate_size: int = 11008
    rope_theta: float = 100000
    rope_scaling_factor: float = 4.0
    vocab_size: int = 32256
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output
class LinearScalingRoPE(nn.RoPE):
    def __init__(
        self, dims: int, rope_scaling_factor: float = 4.0, base: float = 10000
    ):
        super().__init__(dims)
        self.base = base
        self.rope_scaling_factor = rope_scaling_factor

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
            self.rope_scaling_factor,
            N,
            self.dims,
            offset=offset,
            base=self.base,
            dtype=x.dtype,
        )
        rx = self._compute_rope(costheta, sintheta, x)
        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        rope_scaling_factor: float,
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        dtype=mx.float32,
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype)
        # Linear scaling: divide positions by the scaling factor to extend the context window.
        positions = positions / rope_scaling_factor
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        return mx.cos(theta), mx.sin(theta)
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads: int = args.num_attention_heads
        self.num_key_value_heads: int = args.num_key_value_heads
        self.repeats = self.num_attention_heads // self.num_key_value_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.scale = self.head_dim**-0.5

        self.wq = nn.Linear(
            args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
        )
        self.wk = nn.Linear(
            args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
        )
        self.wv = nn.Linear(
            args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
        )
        self.wo = nn.Linear(
            args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
        )
        # Use the scaling factor from the model config (set in load_model).
        self.rope = LinearScalingRoPE(
            self.head_dim,
            rope_scaling_factor=args.rope_scaling_factor,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        def repeat(a):
            # Repeat the key/value heads so they match the number of query heads.
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.num_attention_heads, L, -1])

        keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
        self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache
class DeepseekCoder(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, x, mask=None, cache=None):
        x = self.tok_embeddings(x)

        mask = None
        T = x.shape[1]
        if T > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
            mask = mask.astype(x.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            x, cache[e] = layer(x, mask, cache[e])

        x = self.norm(x)
        return self.output(x), cache
def generate(prompt: mx.array, model: DeepseekCoder, temp: float = 0.0):
    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    logits, cache = model(prompt)
    y = sample(logits[:, -1, :])
    yield y

    while True:
        logits, cache = model(y[:, None], cache=cache)
        y = sample(logits.squeeze(1))
        yield y
def load_model(model_path: str):
    model_args = ModelArgs()
    model_path = Path(model_path)
    with open(model_path / "config.json", "r") as f:
        config = json.load(f)
    model_args.vocab_size = config["vocab_size"]
    model_args.hidden_size = config["hidden_size"]
    model_args.num_attention_heads = config["num_attention_heads"]
    model_args.num_key_value_heads = config["num_key_value_heads"]
    model_args.num_hidden_layers = config["num_hidden_layers"]
    model_args.max_position_embeddings = config["max_position_embeddings"]
    model_args.layer_norm_epsilon = config["rms_norm_eps"]
    model_args.intermediate_size = config["intermediate_size"]
    model_args.rope_scaling_factor = config["rope_scaling"]["factor"]

    model = DeepseekCoder(model_args)
    weights = mx.load(str(model_path / "weights.npz"))
    if quantization := config.get("quantization", False):
        nn.QuantizedLinear.quantize_module(model, **quantization)
    model.update(tree_unflatten(list(weights.items())))

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Deepseek coder inference script")
    parser.add_argument(
        "--model-path",
        type=str,
        default="mlx_model",
        help="The path to the mlx model weights, tokenizer and config",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="### Instruction: \nwrite a quick sort algorithm in python.\n### Response: \n",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.6,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()
    mx.random.seed(args.seed)

    model, tokenizer = load_model(args.model_path)

    prompt = tokenizer(
        args.prompt,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"]
    prompt = mx.array(prompt)

    print(args.prompt, end="", flush=True)

    tokens = []
    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
        tokens.append(token)

        if (len(tokens) % 10) == 0:
            mx.eval(tokens)
            eos_index = next(
                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
                None,
            )
            if eos_index is not None:
                tokens = tokens[:eos_index]
            s = tokenizer.decode([t.item() for t in tokens])
            print(s, end="", flush=True)
            tokens = []
            if eos_index is not None:
                break

    mx.eval(tokens)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s, flush=True)

requirements.txt
@@ -0,0 +1,4 @@
torch
mlx
numpy
transformers>=4.35