From e17e07002a4db47b15ce5cace10fc12c8b50af83 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 22 Dec 2023 17:51:06 +1100
Subject: [PATCH] feat: add example for deepseek coder
---
llms/deepseek-coder/README.md | 30 +++
llms/deepseek-coder/convert.py | 84 +++++++
llms/deepseek-coder/deepseek-coder.py | 324 ++++++++++++++++++++++++++
llms/deepseek-coder/requirements.txt | 4 +
4 files changed, 442 insertions(+)
create mode 100644 llms/deepseek-coder/README.md
create mode 100644 llms/deepseek-coder/convert.py
create mode 100644 llms/deepseek-coder/deepseek-coder.py
create mode 100644 llms/deepseek-coder/requirements.txt
diff --git a/llms/deepseek-coder/README.md b/llms/deepseek-coder/README.md
new file mode 100644
index 00000000..4d38664b
--- /dev/null
+++ b/llms/deepseek-coder/README.md
@@ -0,0 +1,30 @@
+# Deepseek Coder
+
+Deepseek Coder is an advanced series of code language models based on LLama architecture, trained from scratch on a massive corpus of 2T tokens, with a unique composition of 87% code and 13% natural language in both English and Chinese.
+
+### Setup
+
+Install the dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+Next, download and convert the model.
+```sh
+python convert.py --model-path --mlx-path
+```
+
+By default, the conversion script will save
+the converted `weights.npz`, `tokenizer`, and `config.json` there in the mlx-path you speficied .
+
+
+### Run
+
+Once you've converted the weights to MLX format, you can interact with the
+Deepseek coder model:
+
+```
+python deepseek-coder.py --model-path --prompt "write a quick sort algorithm in python."
+```
+
diff --git a/llms/deepseek-coder/convert.py b/llms/deepseek-coder/convert.py
new file mode 100644
index 00000000..689c3359
--- /dev/null
+++ b/llms/deepseek-coder/convert.py
@@ -0,0 +1,84 @@
+import argparse
+from pathlib import Path
+import json
+
+import numpy as np
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def convert(args):
+ model_path = Path(args.model_path)
+
+ mlx_path = Path(args.mlx_path)
+ mlx_path.mkdir(parents=True, exist_ok=True)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ str(model_path), trust_remote_code=True, torch_dtype=torch.float16
+ )
+ config = model.config.to_dict()
+
+ state_dict = model.state_dict()
+ tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+ # things to change
+ # 1. there's no "model." in the weight names
+ state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
+
+ # 2. mlp is called feed_forward
+ state_dict = {k.replace("mlp", "feed_forward"): v for k, v in state_dict.items()}
+
+ # 3. up_proj, down_proj, gate_proj
+ state_dict = {k.replace("down_proj", "w2"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("up_proj", "w3"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("gate_proj", "w1"): v for k, v in state_dict.items()}
+
+ # 4. layernorms
+ state_dict = {
+ k.replace("input_layernorm", "attention_norm"): v for k, v in state_dict.items()
+ }
+ state_dict = {
+ k.replace("post_attention_layernorm", "ffn_norm"): v
+ for k, v in state_dict.items()
+ }
+
+ # 5. lm head
+ state_dict = {k.replace("lm_head", "output"): v for k, v in state_dict.items()}
+
+ # 6. token emb
+ state_dict = {
+ k.replace("embed_tokens", "tok_embeddings"): v for k, v in state_dict.items()
+ }
+
+ # 7. attention
+ state_dict = {k.replace("self_attn", "attention"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("q_proj", "wq"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("k_proj", "wk"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
+
+ weights = {k: v.numpy() for k, v in state_dict.items()}
+
+ np.savez(str(mlx_path / "weights.npz"), **weights)
+ tokenizer.save_pretrained(mlx_path)
+ with open(mlx_path / "config.json", "w") as f:
+ json.dump(config, f, indent=4)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert Deepseek coder model to npz")
+
+ parser.add_argument(
+ "--model-path",
+ help="The huggingface model to be converted",
+ default="deepseek-ai/deepseek-coder-6.7b-instruct",
+ )
+
+ parser.add_argument(
+ "--mlx-path",
+ type=str,
+ default="mlx_model",
+ help="The path to save the MLX model.",
+ )
+ args = parser.parse_args()
+ convert(args)
diff --git a/llms/deepseek-coder/deepseek-coder.py b/llms/deepseek-coder/deepseek-coder.py
new file mode 100644
index 00000000..de52eba3
--- /dev/null
+++ b/llms/deepseek-coder/deepseek-coder.py
@@ -0,0 +1,324 @@
+import argparse
+import math
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from transformers import AutoTokenizer
+
+
+@dataclass
+class ModelArgs:
+ hidden_size: int = 4096
+ num_attention_heads: int = 32
+ num_hidden_layers: int = 32
+ num_key_value_heads: int = 32
+ max_position_embeddings: int = 16384
+ layer_norm_epsilon: float = 1e-6
+ intermediate_size: int = 11008
+ rope_theta: float = 100000
+ rope_scaling_factor: float = 4.0
+ vocab_size: int = 32256
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dims: int, eps: float = 1e-5):
+ super().__init__()
+ self.weight = mx.ones((dims,))
+ self.eps = eps
+
+ def _norm(self, x):
+ return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
+
+ def __call__(self, x):
+ output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+ return self.weight * output
+
+
+class LinearScalingRoPE(nn.RoPE):
+ def __init__(
+ self, dims: int, rope_scaling_factor: float = 4.0, base: float = 10000
+ ):
+ super().__init__(dims)
+ self.base = base
+ self.rope_scaling_factor = rope_scaling_factor
+
+ def __call__(self, x, offset: int = 0):
+ shape = x.shape
+ x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+ N = x.shape[1] + offset
+ costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
+ self.rope_scaling_factor,
+ N,
+ self.dims,
+ offset=offset,
+ base=self.base,
+ dtype=x.dtype,
+ )
+
+ rx = self._compute_rope(costheta, sintheta, x)
+
+ return mx.reshape(rx, shape)
+
+ @staticmethod
+ def create_cos_sin_theta(
+ rope_scaling_factor: float,
+ N: int,
+ D: int,
+ offset: int = 0,
+ base: float = 10000,
+ dtype=mx.float32,
+ ):
+ D = D // 2
+ positions = mx.arange(offset, N, dtype=dtype)
+ positions = positions / rope_scaling_factor
+ freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
+ theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+ return mx.cos(theta), mx.sin(theta)
+
+
+class Attention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.num_attention_heads: int = args.num_attention_heads
+ self.num_key_value_heads: int = args.num_key_value_heads
+ self.repeats = self.num_attention_heads // self.num_key_value_heads
+
+ self.head_dim = args.hidden_size // args.num_attention_heads
+
+ self.scale = self.head_dim**-0.5
+
+ self.wq = nn.Linear(
+ args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
+ )
+ self.wk = nn.Linear(
+ args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
+ )
+ self.wv = nn.Linear(
+ args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
+ )
+ self.wo = nn.Linear(
+ args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
+ )
+ self.rope = LinearScalingRoPE(
+ self.head_dim, rope_scaling_factor=4.0, base=args.rope_theta
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Tuple[mx.array, mx.array]] = None,
+ ) -> mx.array:
+ B, L, D = x.shape
+
+ queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+
+ # Prepare the queries, keys and values for the attention computation
+ queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
+ 0, 2, 1, 3
+ )
+ keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+ 0, 2, 1, 3
+ )
+
+ def repeat(a):
+ a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+ return a.reshape([B, self.num_attention_heads, L, -1])
+
+ keys, values = map(repeat, (keys, values))
+
+ if cache is not None:
+ key_cache, value_cache = cache
+ queries = self.rope(queries, offset=key_cache.shape[2])
+ keys = self.rope(keys, offset=key_cache.shape[2])
+ keys = mx.concatenate([key_cache, keys], axis=2)
+ values = mx.concatenate([value_cache, values], axis=2)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+
+ scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+ if mask is not None:
+ scores += mask
+ scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+ output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.wo(output), (keys, values)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+ self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
+ self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+
+ def __call__(self, x) -> mx.array:
+ return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.attention = Attention(args)
+ self.feed_forward = FeedForward(args=args)
+ self.attention_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+ self.ffn_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Tuple[mx.array, mx.array]] = None,
+ ) -> mx.array:
+ r, cache = self.attention(self.attention_norm(x), mask, cache)
+ h = x + r
+ r = self.feed_forward(self.ffn_norm(h))
+ out = h + r
+ return out, cache
+
+
+class DeepseekCoder(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.vocab_size = args.vocab_size
+ self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [
+ TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+ ]
+ self.norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+ self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+ def __call__(self, x, mask=None, cache=None):
+ x = self.tok_embeddings(x)
+ mask = None
+ T = x.shape[1]
+ if T > 1:
+ mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+ mask = mask.astype(x.dtype)
+
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ for e, layer in enumerate(self.layers):
+ x, cache[e] = layer(x, mask, cache[e])
+ x = self.norm(x)
+ return self.output(x), cache
+
+
+def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
+ def sample(logits):
+ if temp == 0:
+ return mx.argmax(logits, axis=-1)
+ else:
+ return mx.random.categorical(logits * (1 / temp))
+
+ logits, cache = model(prompt)
+ y = sample(logits[:, -1, :])
+ yield y
+
+ while True:
+ logits, cache = model(y[:, None], cache=cache)
+ y = sample(logits.squeeze(1))
+ yield y
+
+
+def load_model(model_path: str):
+ model_args = ModelArgs()
+
+ model_path = Path(model_path)
+ with open(model_path / "config.json", "r") as f:
+ config = json.load(f)
+ model_args.vocab_size = config["vocab_size"]
+ model_args.hidden_size = config["hidden_size"]
+ model_args.num_attention_heads = config["num_attention_heads"]
+ model_args.num_key_value_heads = config["num_key_value_heads"]
+ model_args.num_hidden_layers = config["num_hidden_layers"]
+ model_args.max_position_embeddings = config["max_position_embeddings"]
+ model_args.layer_norm_epsilon = config["rms_norm_eps"]
+ model_args.intermediate_size = config["intermediate_size"]
+ model_args.rope_scaling_factor = config["rope_scaling"]["factor"]
+
+ model = DeepseekCoder(model_args)
+ weights = mx.load(str(model_path / "weights.npz"))
+ if quantization := config.get("quantization", False):
+ nn.QuantizedLinear.quantize_module(model, **quantization)
+ model.update(tree_unflatten(list(weights.items())))
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ return model, tokenizer
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Deepseek coder inference script")
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default="mlx_model",
+ help="The path to the mlx model weights, tokenizer and config",
+ )
+
+ parser.add_argument(
+ "--prompt",
+ help="The message to be processed by the model",
+ default="### Instruction: \nwrite a quick sort algorithm in python.\n### Response: \n",
+ )
+ parser.add_argument(
+ "--max-tokens",
+ "-m",
+ type=int,
+ default=500,
+ help="Maximum number of tokens to generate",
+ )
+ parser.add_argument(
+ "--temp",
+ help="The sampling temperature.",
+ type=float,
+ default=0.6,
+ )
+ parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+ args = parser.parse_args()
+
+ mx.random.seed(args.seed)
+
+ model, tokenizer = load_model(args.model_path)
+
+ prompt = tokenizer(
+ args.prompt,
+ return_tensors="np",
+ return_attention_mask=False,
+ )["input_ids"]
+
+ prompt = mx.array(prompt)
+
+ print(args.prompt, end="", flush=True)
+
+ tokens = []
+ for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+ tokens.append(token)
+
+ if (len(tokens) % 10) == 0:
+ mx.eval(tokens)
+ eos_index = next(
+ (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+ None,
+ )
+
+ if eos_index is not None:
+ tokens = tokens[:eos_index]
+
+ s = tokenizer.decode([t.item() for t in tokens])
+ print(s, end="", flush=True)
+ tokens = []
+ if eos_index is not None:
+ break
+
+ mx.eval(tokens)
+ s = tokenizer.decode([t.item() for t in tokens])
+ print(s, flush=True)
diff --git a/llms/deepseek-coder/requirements.txt b/llms/deepseek-coder/requirements.txt
new file mode 100644
index 00000000..3417c23b
--- /dev/null
+++ b/llms/deepseek-coder/requirements.txt
@@ -0,0 +1,4 @@
+torch
+mlx
+numpy
+transformers>=4.35
\ No newline at end of file