chore: add workaround for quant

This commit is contained in:
Anchen 2024-01-01 14:56:21 +11:00
parent b763ad3829
commit d81dcad68f
2 changed files with 42 additions and 5 deletions

View File

@ -84,6 +84,23 @@ def convert(args):
state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
# quantization workaround for kv proj
keys_to_delete =[]
new_state_dict ={}
for k, v in state_dict.items():
if k.endswith('wk.weight'):
prefix = k[:-len('wk.weight')]
wv = prefix +"wv.weight"
if wv in state_dict:
wkwv = torch.cat([v, state_dict[wv]],dim=0)
new_key = prefix + "wkwv.weight"
new_state_dict[new_key] = wkwv
keys_to_delete.extend([k,wv])
for key in keys_to_delete:
del state_dict[key]
state_dict.update(new_state_dict)
weights = {k: v.numpy() for k, v in state_dict.items()}
keep_keys = set(
@ -152,3 +169,4 @@ if __name__ == "__main__":
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "yayi"
json.dump(config, f, indent=4)

View File

@ -1,6 +1,7 @@
import argparse
import json
from dataclasses import dataclass
import math
from pathlib import Path
from typing import Optional, Tuple
@ -23,6 +24,21 @@ class ModelArgs:
rope_traditional: bool = False
class KVLinear(nn.Module):
def __init__(self, input_dims: int, output_dims: int):
super().__init__()
scale = math.sqrt(1 / input_dims)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims),
)
def __call__(self, x):
k, v = mx.split(x @ self.weight.T, 2, axis=-1)
return k, v
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@ -49,8 +65,8 @@ class Attention(nn.Module):
self.wq = nn.Linear(
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
)
self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
self.wkwv = KVLinear(args.hidden_size, int(self.head_dim) * 2)
self.wo = nn.Linear(
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
)
@ -66,7 +82,8 @@ class Attention(nn.Module):
) -> mx.array:
B, L, _ = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = self.wq(x)
k, v = self.wkwv(x)
q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
0, 2, 1, 3
@ -189,7 +206,9 @@ def load_model(model_path: str):
weights = mx.load(str(model_path / "weights.npz"))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
parameteres = tree_unflatten(list(weights.items()))
model.update(parameteres)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return model, tokenizer
@ -206,7 +225,7 @@ if __name__ == "__main__":
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="The winter in Beijing is",
default="The winter in Beijing is ",
)
parser.add_argument(
"--max-tokens",