chore: add workaround for quant

Anchen committed 2024-01-01 14:56:21 +11:00
parent b763ad3829
commit d81dcad68f
2 changed files with 42 additions and 5 deletions

View File

@@ -84,6 +84,23 @@ def convert(args):
     state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
     state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
+
+    # quantization workaround for kv proj
+    keys_to_delete = []
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        if k.endswith('wk.weight'):
+            prefix = k[:-len('wk.weight')]
+            wv = prefix + "wv.weight"
+            if wv in state_dict:
+                wkwv = torch.cat([v, state_dict[wv]], dim=0)
+                new_key = prefix + "wkwv.weight"
+                new_state_dict[new_key] = wkwv
+                keys_to_delete.extend([k, wv])
+    for key in keys_to_delete:
+        del state_dict[key]
+    state_dict.update(new_state_dict)
+
     weights = {k: v.numpy() for k, v in state_dict.items()}
     keep_keys = set(
@@ -152,3 +169,4 @@ if __name__ == "__main__":
     with open(mlx_path / "config.json", "w") as f:
         config["model_type"] = "yayi"
         json.dump(config, f, indent=4)
+
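Why the converter-side fusion round-trips: stacking wk and wv along dim 0 means a single matmul against the fused weight produces k and v side by side on the last axis, and splitting that output in half recovers each projection exactly. A minimal sanity check, with illustrative shapes not taken from the commit:

import torch

# hypothetical sizes: one shared KV head of width head_dim
head_dim, hidden_size = 64, 256
wk = torch.randn(head_dim, hidden_size)
wv = torch.randn(head_dim, hidden_size)
wkwv = torch.cat([wk, wv], dim=0)  # (2 * head_dim, hidden_size), as above

x = torch.randn(1, 3, hidden_size)
k, v = (x @ wkwv.T).split(head_dim, dim=-1)  # split order matches the cat order
assert torch.allclose(k, x @ wk.T)
assert torch.allclose(v, x @ wv.T)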

View File

@@ -1,6 +1,7 @@
 import argparse
 import json
 from dataclasses import dataclass
+import math
 from pathlib import Path
 from typing import Optional, Tuple
@@ -23,6 +24,21 @@ class ModelArgs:
     rope_traditional: bool = False


+class KVLinear(nn.Module):
+    def __init__(self, input_dims: int, output_dims: int):
+        super().__init__()
+        scale = math.sqrt(1 / input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims, input_dims),
+        )
+
+    def __call__(self, x):
+        k, v = mx.split(x @ self.weight.T, 2, axis=-1)
+        return k, v
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
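KVLinear undoes the fusion at inference time: one matmul against the fused weight, then mx.split into two equal halves along the last axis, k first and v second to match the torch.cat order in the converter. Since it is not an nn.Linear, nn.QuantizedLinear.quantize_module should leave it untouched, which is presumably the point of the workaround. A hypothetical usage sketch of the class above (shapes are illustrative):

import mlx.core as mx

layer = KVLinear(input_dims=256, output_dims=128)  # 128 = 2 * head_dim of 64
x = mx.random.normal((1, 3, 256))
k, v = layer(x)
print(k.shape, v.shape)  # (1, 3, 64) (1, 3, 64)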
@@ -49,8 +65,8 @@ class Attention(nn.Module):
         self.wq = nn.Linear(
             args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
         )
-        self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
-        self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
+        self.wkwv = KVLinear(args.hidden_size, int(self.head_dim) * 2)
         self.wo = nn.Linear(
             args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
         )
@@ -66,7 +82,8 @@ class Attention(nn.Module):
     ) -> mx.array:
         B, L, _ = x.shape
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
+        q = self.wq(x)
+        k, v = self.wkwv(x)
         q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
             0, 2, 1, 3
@@ -189,7 +206,9 @@ def load_model(model_path: str):
     weights = mx.load(str(model_path / "weights.npz"))
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
-    model.update(tree_unflatten(list(weights.items())))
+    parameters = tree_unflatten(list(weights.items()))
+    model.update(parameters)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     return model, tokenizer
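The load_model change only names the intermediate result, but it shows the load order that makes the workaround hold together: quantize_module rewrites eligible nn.Linear layers first, then the fused weights from weights.npz are unflattened into the module tree. A minimal sketch of what tree_unflatten does with one flat key (the key path and shape here are illustrative, not from the commit):

import mlx.core as mx
from mlx.utils import tree_unflatten

flat = [("layers.0.attention.wkwv.weight", mx.zeros((128, 256)))]
nested = tree_unflatten(flat)
# nested == {"layers": [{"attention": {"wkwv": {"weight": <array>}}}]}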
@@ -206,7 +225,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prompt",
         help="The message to be processed by the model",
-        default="The winter in Beijing is",
+        default="The winter in Beijing is ",
     )
     parser.add_argument(
         "--max-tokens",