From e17e07002a4db47b15ce5cace10fc12c8b50af83 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 22 Dec 2023 17:51:06 +1100
Subject: [PATCH] feat: add example for deepseek coder

---
 llms/deepseek-coder/README.md         |  30 +++
 llms/deepseek-coder/convert.py        |  84 +++++++
 llms/deepseek-coder/deepseek-coder.py | 324 ++++++++++++++++++++++++++
 llms/deepseek-coder/requirements.txt  |   4 +
 4 files changed, 442 insertions(+)
 create mode 100644 llms/deepseek-coder/README.md
 create mode 100644 llms/deepseek-coder/convert.py
 create mode 100644 llms/deepseek-coder/deepseek-coder.py
 create mode 100644 llms/deepseek-coder/requirements.txt

diff --git a/llms/deepseek-coder/README.md b/llms/deepseek-coder/README.md
new file mode 100644
index 00000000..4d38664b
--- /dev/null
+++ b/llms/deepseek-coder/README.md
@@ -0,0 +1,30 @@
+# Deepseek Coder
+
+Deepseek Coder is a series of code language models based on the Llama architecture, trained from scratch on a 2T-token corpus composed of 87% code and 13% natural language in both English and Chinese.
+
+### Setup
+
+Install the dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+Next, download and convert the model:
+
+```sh
+python convert.py --model-path <path_to_huggingface_model> --mlx-path <path_to_save_mlx_model>
+```
+
+By default, the conversion script saves the converted `weights.npz`, the tokenizer files, and `config.json` in the `--mlx-path` you specified.
+
+### Run
+
+Once you've converted the weights to MLX format, you can interact with the
+Deepseek Coder model:
+
+```
+python deepseek-coder.py --model-path <path_to_mlx_model> --prompt "write a quick sort algorithm in python."
+```
diff --git a/llms/deepseek-coder/convert.py b/llms/deepseek-coder/convert.py
new file mode 100644
index 00000000..689c3359
--- /dev/null
+++ b/llms/deepseek-coder/convert.py
@@ -0,0 +1,84 @@
+import argparse
+from pathlib import Path
+import json
+
+import numpy as np
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def convert(args):
+    model_path = Path(args.model_path)
+
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        str(model_path), trust_remote_code=True, torch_dtype=torch.float16
+    )
+    config = model.config.to_dict()
+
+    state_dict = model.state_dict()
+    tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+    # Rename the Hugging Face weights to the names the MLX model expects:
+    # 1. there's no "model." prefix in the weight names
+    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
+
+    # 2. mlp is called feed_forward
+    state_dict = {k.replace("mlp", "feed_forward"): v for k, v in state_dict.items()}
+
+    # 3. up_proj, down_proj, gate_proj
+    state_dict = {k.replace("down_proj", "w2"): v for k, v in state_dict.items()}
+    state_dict = {k.replace("up_proj", "w3"): v for k, v in state_dict.items()}
+    state_dict = {k.replace("gate_proj", "w1"): v for k, v in state_dict.items()}
+
+    # 4. layernorms
+    state_dict = {
+        k.replace("input_layernorm", "attention_norm"): v for k, v in state_dict.items()
+    }
+    state_dict = {
+        k.replace("post_attention_layernorm", "ffn_norm"): v
+        for k, v in state_dict.items()
+    }
+
+    # 5. lm head
+    state_dict = {k.replace("lm_head", "output"): v for k, v in state_dict.items()}
+
+    # 6. token emb
+    state_dict = {
+        k.replace("embed_tokens", "tok_embeddings"): v for k, v in state_dict.items()
+    }
+
+    # 7. attention
+    state_dict = {k.replace("self_attn", "attention"): v for k, v in state_dict.items()}
+    state_dict = {k.replace("q_proj", "wq"): v for k, v in state_dict.items()}
+    state_dict = {k.replace("k_proj", "wk"): v for k, v in state_dict.items()}
+    state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
+    state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
+
+    weights = {k: v.numpy() for k, v in state_dict.items()}
+
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+    tokenizer.save_pretrained(mlx_path)
+    with open(mlx_path / "config.json", "w") as f:
+        json.dump(config, f, indent=4)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Deepseek coder model to npz")
+
+    parser.add_argument(
+        "--model-path",
+        help="The huggingface model to be converted",
+        default="deepseek-ai/deepseek-coder-6.7b-instruct",
+    )
+
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="The path to save the MLX model.",
+    )
+    args = parser.parse_args()
+    convert(args)
diff --git a/llms/deepseek-coder/deepseek-coder.py b/llms/deepseek-coder/deepseek-coder.py
new file mode 100644
index 00000000..de52eba3
--- /dev/null
+++ b/llms/deepseek-coder/deepseek-coder.py
@@ -0,0 +1,324 @@
+import argparse
+import math
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from transformers import AutoTokenizer
+
+
+@dataclass
+class ModelArgs:
+    hidden_size: int = 4096
+    num_attention_heads: int = 32
+    num_hidden_layers: int = 32
+    num_key_value_heads: int = 32
+    max_position_embeddings: int = 16384
+    layer_norm_epsilon: float = 1e-6
+    intermediate_size: int = 11008
+    rope_theta: float = 100000
+    rope_scaling_factor: float = 4.0
+    vocab_size: int = 32256
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def _norm(self, x):
+        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
+
+    def __call__(self, x):
+        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        return self.weight * output
+
+
+class LinearScalingRoPE(nn.RoPE):
+    def __init__(
+        self, dims: int, rope_scaling_factor: float = 4.0, base: float = 10000
+    ):
+        super().__init__(dims)
+        self.base = base
+        self.rope_scaling_factor = rope_scaling_factor
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
+            self.rope_scaling_factor,
+            N,
+            self.dims,
+            offset=offset,
+            base=self.base,
+            dtype=x.dtype,
+        )
+
+        rx = self._compute_rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+    @staticmethod
+    def create_cos_sin_theta(
+        rope_scaling_factor: float,
+        N: int,
+        D: int,
+        offset: int = 0,
+        base: float = 10000,
+        dtype=mx.float32,
+    ):
+        D = D // 2
+        positions = mx.arange(offset, N, dtype=dtype)
+        positions = positions / rope_scaling_factor
+        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
+        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+        return mx.cos(theta), mx.sin(theta)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_attention_heads: int = args.num_attention_heads
+        self.num_key_value_heads: int = args.num_key_value_heads
+        self.repeats = self.num_attention_heads // self.num_key_value_heads
+
+        self.head_dim = args.hidden_size // args.num_attention_heads
+
+        self.scale = self.head_dim**-0.5
+
+        self.wq = nn.Linear(
+            args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
+        )
+        self.wk = nn.Linear(
+            args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.wv = nn.Linear(
+            args.hidden_size, args.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.wo = nn.Linear(
+            args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
+        )
+        self.rope = LinearScalingRoPE(
+            self.head_dim, rope_scaling_factor=args.rope_scaling_factor, base=args.rope_theta
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+
+        def repeat(a):
+            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+            return a.reshape([B, self.num_attention_heads, L, -1])
+
+        keys, values = map(repeat, (keys, values))
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+        if mask is not None:
+            scores += mask
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.wo(output), (keys, values)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+        self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
+        self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(args=args)
+        self.attention_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ffn_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r, cache = self.attention(self.attention_norm(x), mask, cache)
+        h = x + r
+        r = self.feed_forward(self.ffn_norm(h))
+        out = h + r
+        return out, cache
+
+
+class DeepseekCoder(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(self, x, mask=None, cache=None):
+        x = self.tok_embeddings(x)
+        mask = None
+        T = x.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+            mask = mask.astype(x.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for e, layer in enumerate(self.layers):
+            x, cache[e] = layer(x, mask, cache[e])
+        x = self.norm(x)
+        return self.output(x), cache
+
+
+def generate(prompt: mx.array, model: DeepseekCoder, temp: float = 0.0):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))
+
+    logits, cache = model(prompt)
+    y = sample(logits[:, -1, :])
+    yield y
+
+    while True:
+        logits, cache = model(y[:, None], cache=cache)
+        y = sample(logits.squeeze(1))
+        yield y
+
+
+def load_model(model_path: str):
+    model_args = ModelArgs()
+
+    model_path = Path(model_path)
+    with open(model_path / "config.json", "r") as f:
+        config = json.load(f)
+        model_args.vocab_size = config["vocab_size"]
+        model_args.hidden_size = config["hidden_size"]
+        model_args.num_attention_heads = config["num_attention_heads"]
+        model_args.num_key_value_heads = config["num_key_value_heads"]
+        model_args.num_hidden_layers = config["num_hidden_layers"]
+        model_args.max_position_embeddings = config["max_position_embeddings"]
+        model_args.layer_norm_epsilon = config["rms_norm_eps"]
+        model_args.intermediate_size = config["intermediate_size"]
+        model_args.rope_scaling_factor = config["rope_scaling"]["factor"]
+
+    model = DeepseekCoder(model_args)
+    weights = mx.load(str(model_path / "weights.npz"))
+    if quantization := config.get("quantization", False):
+        nn.QuantizedLinear.quantize_module(model, **quantization)
+    model.update(tree_unflatten(list(weights.items())))
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Deepseek coder inference script")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="mlx_model",
+        help="The path to the mlx model weights, tokenizer and config",
+    )
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="### Instruction: \nwrite a quick sort algorithm in python.\n### Response: \n",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=500,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temp",
+        help="The sampling temperature.",
+        type=float,
+        default=0.6,
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    args = parser.parse_args()
+
+    mx.random.seed(args.seed)
+
+    model, tokenizer = load_model(args.model_path)
+
+    prompt = tokenizer(
+        args.prompt,
+        return_tensors="np",
+        return_attention_mask=False,
+    )["input_ids"]
+
+    prompt = mx.array(prompt)
+
+    print(args.prompt, end="", flush=True)
+
+    tokens = []
+    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+        tokens.append(token)
+
+        if (len(tokens) % 10) == 0:
+            mx.eval(tokens)
+            eos_index = next(
+                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+                None,
+            )
+
+            if eos_index is not None:
+                tokens = tokens[:eos_index]
+
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s, end="", flush=True)
+            tokens = []
+            if eos_index is not None:
+                break
+
+    mx.eval(tokens)
+    s = tokenizer.decode([t.item() for t in tokens])
+    print(s, flush=True)
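The `LinearScalingRoPE` in `deepseek-coder.py` extends the usable context by dividing token positions by `rope_scaling_factor` (4 for this model, matching the `rope_scaling` entry in its `config.json`) before computing the standard RoPE angles. The snippet below is a minimal standalone sketch of that property, mirroring `create_cos_sin_theta`; the helper name and the toy sizes are illustrative only.

```python
import math

import mlx.core as mx


def scaled_angles(N, D, rope_scaling_factor=1.0, base=100000.0):
    # Mirrors LinearScalingRoPE.create_cos_sin_theta: positions are divided by
    # the scaling factor before the usual RoPE frequencies are applied.
    half = D // 2
    positions = mx.arange(0, N, dtype=mx.float32) / rope_scaling_factor
    freqs = mx.exp(-mx.arange(0.0, half, dtype=mx.float32) * (math.log(base) / half))
    return mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))


theta_plain = scaled_angles(16, 128, rope_scaling_factor=1.0)
theta_scaled = scaled_angles(16, 128, rope_scaling_factor=4.0)

# With a factor of 4, position 8 is rotated exactly as position 2 would be
# without scaling, so the 16384-token window in the config reuses angles from
# a range four times shorter.
print(mx.abs(theta_scaled[8] - theta_plain[2]).max())  # ~0.0
```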
diff --git a/llms/deepseek-coder/requirements.txt b/llms/deepseek-coder/requirements.txt
new file mode 100644
index 00000000..3417c23b
--- /dev/null
+++ b/llms/deepseek-coder/requirements.txt
@@ -0,0 +1,4 @@
+torch
+mlx
+numpy
+transformers>=4.35
\ No newline at end of file
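For a programmatic variant of the README's run step, here is a minimal sketch that loads the converted model and samples greedily. It assumes the weights live in `mlx_model` (the default `--mlx-path`) and that `load_model` and `generate` have been copied into an importable module; the module name `deepseek_coder` below is hypothetical, since the hyphenated script name cannot be imported directly.

```python
import mlx.core as mx

# Hypothetical module: copy load_model and generate out of deepseek-coder.py
# (or rename the script) so they can be imported.
from deepseek_coder import generate, load_model

model, tokenizer = load_model("mlx_model")

prompt = "### Instruction: \nwrite a quick sort algorithm in python.\n### Response: \n"
x = mx.array(
    tokenizer(prompt, return_tensors="np", return_attention_mask=False)["input_ids"]
)

# Greedy decoding (temp=0.0), stopping at EOS or after 256 tokens.
tokens = []
for token, _ in zip(generate(x, model, temp=0.0), range(256)):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))
```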