style / consistency changes to ease future integration

Awni Hannun
2023-12-28 21:31:29 -08:00
parent a476d1909d
commit 9ff0a96ab0
8 changed files with 94 additions and 116 deletions

View File

@@ -1,6 +1,9 @@
 # Deepseek Coder
-Deepseek Coder is an advanced series of code language models based on LLama architecture, trained from scratch on a massive corpus of 2T tokens, with a unique composition of 87% code and 13% natural language in both English and Chinese.
+Deepseek Coder is a family of code generating language models based on the
+LLama architecture.[^1] The models were trained from scratch on a corpus of 2T
+tokens, with a composition of 87% code and 13% natural language containing both
+English and Chinese.
 ### Setup
@@ -11,19 +14,23 @@ pip install -r requirements.txt
 ```
 Next, download and convert the model.
 ```sh
-python convert.py --model-path <path_to_huggingface_model> --mlx-path <path_to_save_converted_model>
+python convert.py --hf-path <path_to_huggingface_model> --mlx-path <path_to_save_converted_model>
 ```
-To generate a 4-bit quantized model, use -q. For a full list of options:
+To generate a 4-bit quantized model, use `-q`. For a full list of options run:
 ```
 python convert.py --help
 ```
-This process retrieves the model from Hugging Face. The default model is deepseek-ai/deepseek-coder-6.7b-instruct. Check out the [Hugging Face page](https://huggingface.co/deepseek-ai) to see a list of available models.
-By default, the conversion script will save
-the converted `weights.npz`, `tokenizer`, and `config.json` there in the mlx-path you speficied .
+The converter downloads the model from Hugging Face. The default model is
+`deepseek-ai/deepseek-coder-6.7b-instruct`. Check out the Hugging Face
+page[^1] to see a list of available models.
+By default, the conversion script will save the converted `weights.npz`,
+`tokenizer`, and `config.json` in the path provided by `--mlx-path`.
 ### Run
@@ -35,3 +42,4 @@ Deepseek coder model:
 python deepseek-coder.py --model-path <path_to_save_converted_model> --prompt "write a quick sort algorithm in python."
 ```
+[^1] For more information see the [Hugging Face page](https://huggingface.co/deepseek-ai).
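For readers following along, here is a minimal sketch of inspecting the conversion output. The directory name `mlx_model` is an assumption (use whatever was passed to `--mlx-path`); the converter writes `weights.npz`, the tokenizer files, and the pruned `config.json` there:

```python
import json

import numpy as np

# "mlx_model" is an assumed output directory; substitute your --mlx-path.
weights = np.load("mlx_model/weights.npz")
print(f"{len(weights.files)} weight arrays saved")

with open("mlx_model/config.json") as f:
    config = json.load(f)
print(config)  # pruned model config, e.g. hidden_size, num_hidden_layers, ...
```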

View File

@@ -7,25 +7,16 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 import torch
+from deepseek_coder import DeepseekCoder, ModelArgs
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from deepseek_coder import ModelArgs, DeepseekCoder

 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
     # Load the model:
-    model_args = ModelArgs()
-    model_args.vocab_size = config["vocab_size"]
-    model_args.hidden_size = config["hidden_size"]
-    model_args.num_attention_heads = config["num_attention_heads"]
-    model_args.num_key_value_heads = config["num_key_value_heads"]
-    model_args.num_hidden_layers = config["num_hidden_layers"]
-    model_args.max_position_embeddings = config["max_position_embeddings"]
-    model_args.rms_norm_eps = config["rms_norm_eps"]
-    model_args.intermediate_size = config["intermediate_size"]
-    model_args.rope_scaling_factor = config["rope_scaling"]["factor"]
+    model_args = ModelArgs(**config)
     model = DeepseekCoder(model_args)

     weights = tree_map(mx.array, weights)
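The field-by-field copying is replaced by unpacking the (already pruned) config directly into the dataclass constructor. A minimal sketch of that pattern, using a hypothetical `TinyArgs` rather than the real `ModelArgs`:

```python
from dataclasses import dataclass


# Hypothetical stand-in for ModelArgs, only to illustrate dict unpacking.
@dataclass
class TinyArgs:
    hidden_size: int = 1024
    num_hidden_layers: int = 8


config = {"hidden_size": 4096, "num_hidden_layers": 32}
args = TinyArgs(**config)  # each key becomes a constructor keyword argument
print(args)  # TinyArgs(hidden_size=4096, num_hidden_layers=32)
```

This is also why `convert` prunes the config down to `keep_keys` first: an unexpected key would raise a `TypeError` in the dataclass constructor.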
@@ -45,18 +36,15 @@ def quantize(weights, config, args):
 def convert(args):
-    model_path = Path(args.model_path)
-    mlx_path = Path(args.mlx_path)
-    mlx_path.mkdir(parents=True, exist_ok=True)
+    hf_path = Path(args.hf_path)
     model = AutoModelForCausalLM.from_pretrained(
-        str(model_path), trust_remote_code=True, torch_dtype=torch.float16
+        str(hf_path), trust_remote_code=True, torch_dtype=torch.float16
     )
     config = model.config.to_dict()
     state_dict = model.state_dict()
-    tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(str(hf_path), trust_remote_code=True)

     # things to change
     # 1. there's no "model." in the weight names
@@ -96,25 +84,34 @@ def convert(args):
     weights = {k: v.numpy() for k, v in state_dict.items()}
-    if args.quantize:
-        print("[INFO] Quantizing")
-        weights, config = quantize(weights, config, args)
-    np.savez(str(mlx_path / "weights.npz"), **weights)
-    tokenizer.save_pretrained(mlx_path)
-    with open(mlx_path / "config.json", "w") as f:
-        json.dump(config, f, indent=4)
+    config["rope_scaling_factor"] = config["rope_scaling"]["factor"]
+    keep_keys = set(
+        [
+            "vocab_size",
+            "hidden_size",
+            "num_attention_heads",
+            "num_key_value_heads",
+            "num_hidden_layers",
+            "max_position_embeddings",
+            "rms_norm_eps",
+            "intermediate_size",
+            "rope_scaling_factor",
+        ]
+    )
+    for k in list(config.keys()):
+        if k not in keep_keys:
+            config.pop(k)
+    return weights, config, tokenizer

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Deepseek coder model to npz")
     parser.add_argument(
-        "--model-path",
+        "--hf-path",
         help="The huggingface model to be converted",
         default="deepseek-ai/deepseek-coder-6.7b-instruct",
     )
     parser.add_argument(
         "--mlx-path",
         type=str,
@@ -128,16 +125,30 @@ if __name__ == "__main__":
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--q_group_size", "--q-group-size",
help="Group size for quantization.", help="Group size for quantization.",
type=int, type=int,
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--q_bits", "--q-bits",
help="Bits per weight for quantization.", help="Bits per weight for quantization.",
type=int, type=int,
default=4, default=4,
) )
args = parser.parse_args() args = parser.parse_args()
convert(args)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
weights, config, tokenizer = convert(args)
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "deepseek_coder"
json.dump(config, f, indent=4)
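One detail worth noting about the flag renames (`--model-path` → `--hf-path`, `--q_group_size` → `--q-group-size`, `--q_bits` → `--q-bits`): argparse converts hyphens in long option names to underscores when it builds attribute names, so the code still reads `args.hf_path`, `args.q_group_size`, and `args.q_bits` unchanged. A small sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--q-group-size", type=int, default=64)
parser.add_argument("--q-bits", type=int, default=4)
args = parser.parse_args(["--q-group-size", "32"])

# Hyphens become underscores in the parsed namespace.
print(args.q_group_size, args.q_bits)  # 32 4
```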

View File

@@ -1,6 +1,6 @@
 import argparse
-import math
 import json
+import math
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple
@@ -214,22 +214,10 @@ class DeepseekCoder(nn.Module):
         return self.output(x), cache

-def apply_repeat_penalty(logits, context, penalty):
-    if len(context) > 0:
-        indices = mx.array([token.item() for token in context])
-        selected_logists = logits[:, indices]
-        selected_logists = mx.where(
-            selected_logists < 0, selected_logists * penalty, selected_logists / penalty
-        )
-        logits[:, indices] = selected_logists

 def generate(
     prompt: mx.array,
     model: DeepseekCoder,
-    temp: 0.0,
-    generated_tokens,
-    repetition_penalty,
+    temp: float = 0.0,
 ):
     def sample(logits):
         if temp == 0:
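The old `temp: 0.0` used a value where a type annotation belongs; the new `temp: float = 0.0` is the intended keyword default. For context, a self-contained sketch of the sampling helper this signature feeds into; the greedy (`temp == 0`) branch is not visible in the hunk, so the `mx.argmax` line is an assumption:

```python
import mlx.core as mx


def sample(logits: mx.array, temp: float = 0.0) -> mx.array:
    if temp == 0:
        # Greedy decoding: take the most likely token (assumed branch).
        return mx.argmax(logits, axis=-1)
    # Temperature sampling: scale the logits, then draw from the
    # categorical distribution defined by their softmax.
    return mx.random.categorical(logits * (1 / temp))


logits = mx.random.normal((1, 32000))  # fake vocabulary logits
print(sample(logits))            # greedy
print(sample(logits, temp=0.8))  # sampled
```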
@@ -237,34 +225,22 @@ def generate(
         else:
             return mx.random.categorical(logits * (1 / temp))

-    logits, cache = model(prompt)
-    y = sample(logits[:, -1, :])
-    yield y
+    y = prompt
+    cache = None
     while True:
-        logits, cache = model(y[:, None], cache=cache)
-        logits = logits.squeeze(1)
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            apply_repeat_penalty(logits, generated_tokens, repetition_penalty)
+        logits, cache = model(y[None], cache=cache)
+        logits = logits[:, -1, :]
         y = sample(logits)
         yield y

 def load_model(model_path: str):
-    model_args = ModelArgs()
     model_path = Path(model_path)
     with open(model_path / "config.json", "r") as f:
         config = json.load(f)
-    model_args.vocab_size = config["vocab_size"]
-    model_args.hidden_size = config["hidden_size"]
-    model_args.num_attention_heads = config["num_attention_heads"]
-    model_args.num_key_value_heads = config["num_key_value_heads"]
-    model_args.num_hidden_layers = config["num_hidden_layers"]
-    model_args.max_position_embeddings = config["max_position_embeddings"]
-    model_args.rms_norm_eps = config["rms_norm_eps"]
-    model_args.intermediate_size = config["intermediate_size"]
-    model_args.rope_scaling_factor = config["rope_scaling"]["factor"]
+    config.pop("model_type")
+    quantization = config.pop("quantization", None)
+    model_args = ModelArgs(**config)
     model = DeepseekCoder(model_args)

     weights = mx.load(str(model_path / "weights.npz"))
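After `mx.load` returns the flat `weights.npz` mapping, the parameters still have to be pushed back into the module; that step is not visible in this hunk. A toy round trip of the flatten/unflatten pattern (using `mlx.utils` as imported in `convert.py`), under the assumption that the example follows the usual `tree_unflatten` + `Module.update` approach:

```python
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

# Toy module standing in for DeepseekCoder.
model = nn.Linear(4, 4)

# Flatten to a {"dotted.name": array} mapping, as stored in weights.npz ...
flat = dict(tree_flatten(model.parameters()))

# ... and restore the parameters onto the module from that flat mapping.
model.update(tree_unflatten(list(flat.items())))
print(list(flat.keys()))  # e.g. ['weight', 'bias']
```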
@@ -282,9 +258,8 @@ if __name__ == "__main__":
"--model-path", "--model-path",
type=str, type=str,
default="mlx_model", default="mlx_model",
help="The path to the mlx model weights, tokenizer and config", help="The path to the mlx model weights, tokenizer, and config",
) )
parser.add_argument( parser.add_argument(
"--prompt", "--prompt",
help="The message to be processed by the model", help="The message to be processed by the model",
@@ -303,14 +278,6 @@ if __name__ == "__main__":
         type=float,
         default=0.6,
     )
-    parser.add_argument(
-        "--repetition-penalty",
-        help="The parameter for repetition penalty.",
-        type=float,
-        default=1.2,
-    )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

     args = parser.parse_args()
@@ -318,39 +285,25 @@ if __name__ == "__main__":
     model, tokenizer = load_model(args.model_path)

-    prompt = tokenizer(
-        args.prompt,
-        return_tensors="np",
-        return_attention_mask=False,
-    )["input_ids"]
+    prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
+        "input_ids"
+    ][0]
     prompt = mx.array(prompt)
     print(args.prompt, end="", flush=True)

     tokens = []
+    skip = 0
     for token, _ in zip(
-        generate(prompt, model, args.temp, tokens, args.repetition_penalty),
+        generate(prompt, model, args.temp),
         range(args.max_tokens),
     ):
-        tokens.append(token)
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
-                None,
-            )
-            if eos_index is not None:
-                tokens = tokens[:eos_index]
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            tokens = []
-            if eos_index is not None:
-                break
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    print(s, flush=True)
+        if token == tokenizer.eos_token_id:
+            break
+        tokens.append(token.item())
+        s = tokenizer.decode(tokens)
+        print(s[skip:], end="", flush=True)
+        skip = len(s)
+
+    print(tokenizer.decode(tokens)[skip:], flush=True)
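The new loop decodes the full token list on every step and prints only the text beyond `skip`, so output streams to the terminal without re-printing earlier text. A minimal stand-alone illustration of that pattern with a dummy tokenizer (not the real one):

```python
class DummyTokenizer:
    """Hypothetical stand-in whose decode simply concatenates token strings."""

    def decode(self, tokens):
        return "".join(tokens)


tokenizer = DummyTokenizer()
stream = ["Hello", ",", " wor", "ld", "!"]

tokens, skip = [], 0
for token in stream:
    tokens.append(token)
    s = tokenizer.decode(tokens)
    print(s[skip:], end="", flush=True)  # print only the newly decoded suffix
    skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)  # flush whatever remains
```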

View File

@@ -15,6 +15,7 @@ import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten

 def llama(model_path):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
@@ -185,13 +186,13 @@ if __name__ == "__main__":
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--q_group_size", "--q-group-size",
help="Group size for quantization.", help="Group size for quantization.",
type=int, type=int,
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--q_bits", "--q-bits",
help="Bits per weight for quantization.", help="Bits per weight for quantization.",
type=int, type=int,
default=4, default=4,

View File

@@ -57,13 +57,13 @@ if __name__ == "__main__":
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--q_group_size", "--q-group-size",
help="Group size for quantization.", help="Group size for quantization.",
type=int, type=int,
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--q_bits", "--q-bits",
help="Bits per weight for quantization.", help="Bits per weight for quantization.",
type=int, type=int,
default=4, default=4,

View File

@@ -110,13 +110,13 @@ if __name__ == "__main__":
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--q_group_size", "--q-group-size",
help="Group size for quantization.", help="Group size for quantization.",
type=int, type=int,
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--q_bits", "--q-bits",
help="Bits per weight for quantization.", help="Bits per weight for quantization.",
type=int, type=int,
default=4, default=4,

View File

@@ -56,13 +56,13 @@ def convert():
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--q_group_size", "--q-group-size",
help="Group size for quantization.", help="Group size for quantization.",
type=int, type=int,
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--q_bits", "--q-bits",
help="Bits per weight for quantization.", help="Bits per weight for quantization.",
type=int, type=int,
default=4, default=4,

View File

@@ -60,7 +60,12 @@ def convert(args):
         args.model, trust_remote_code=True, torch_dtype=torch.float16
     )
     state_dict = model.state_dict()
-    weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()}
+    weights = {
+        replace_key(k): (
+            v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()
+        )
+        for k, v in state_dict.items()
+    }
     config = model.config.to_dict()

     if args.quantize:
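The reflowed comprehension keeps the same behavior: `bfloat16` tensors are upcast before conversion because NumPy has no `bfloat16` dtype, so calling `.numpy()` on them directly raises. A small sketch:

```python
import torch

t = torch.zeros(2, dtype=torch.bfloat16)

try:
    t.numpy()  # NumPy has no bfloat16 dtype, so this raises
except TypeError as e:
    print(e)

arr = t.to(torch.float32).numpy()  # upcast first, then convert
print(arr.dtype)  # float32
```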
@@ -95,13 +100,13 @@ if __name__ == "__main__":
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
"--q_group_size", "--q-group-size",
help="Group size for quantization.", help="Group size for quantization.",
type=int, type=int,
default=64, default=64,
) )
parser.add_argument( parser.add_argument(
"--q_bits", "--q-bits",
help="Bits per weight for quantization.", help="Bits per weight for quantization.",
type=int, type=int,
default=4, default=4,