From d81dcad68f208230a2f5360055e32fc08b3eb0af Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 1 Jan 2024 14:56:21 +1100
Subject: [PATCH] chore: add workaround for quant
---
llms/yayi2/convert.py | 18 ++++++++++++++++++
llms/yayi2/yayi.py | 29 ++++++++++++++++++++++++-----
2 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/llms/yayi2/convert.py b/llms/yayi2/convert.py
index 5beca05a..51d14634 100644
--- a/llms/yayi2/convert.py
+++ b/llms/yayi2/convert.py
@@ -84,6 +84,23 @@ def convert(args):
state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
+ # quantization workaround: fuse each matching wk.weight / wv.weight pair into a single wkwv.weight (k rows first, then v rows)
+ keys_to_delete =[]
+ new_state_dict ={}
+ for k, v in state_dict.items():
+ if k.endswith('wk.weight'):
+ prefix = k[:-len('wk.weight')]
+ wv = prefix +"wv.weight"
+ if wv in state_dict:
+ wkwv = torch.cat([v, state_dict[wv]],dim=0)
+ new_key = prefix + "wkwv.weight"
+ new_state_dict[new_key] = wkwv
+ keys_to_delete.extend([k,wv])
+
+ for key in keys_to_delete:
+ del state_dict[key]
+
+ state_dict.update(new_state_dict)
weights = {k: v.numpy() for k, v in state_dict.items()}
keep_keys = set(
@@ -152,3 +169,4 @@ if __name__ == "__main__":
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "yayi"
json.dump(config, f, indent=4)
+
diff --git a/llms/yayi2/yayi.py b/llms/yayi2/yayi.py
index 29dfd4cf..0889c83a 100644
--- a/llms/yayi2/yayi.py
+++ b/llms/yayi2/yayi.py
@@ -1,6 +1,7 @@
import argparse
import json
from dataclasses import dataclass
+import math
from pathlib import Path
from typing import Optional, Tuple
@@ -23,6 +24,21 @@ class ModelArgs:
rope_traditional: bool = False
+class KVLinear(nn.Module):  # bias-free projection whose output is split into (k, v); NOTE(review): presumably a plain nn.Module rather than nn.Linear so quantize_module leaves it unquantized - confirm
+ def __init__(self, input_dims: int, output_dims: int):  # output_dims is the combined k+v width; Attention in this patch passes int(head_dim) * 2
+ super().__init__()
+ scale = math.sqrt(1 / input_dims)  # half-width of the uniform init range, i.e. 1 / sqrt(input_dims)
+ self.weight = mx.random.uniform(  # initial weight sampled uniformly between -scale and scale
+ low=-scale,
+ high=scale,
+ shape=(output_dims, input_dims),  # (out, in) layout; __call__ multiplies by the transpose
+ )
+
+ def __call__(self, x):  # project x, then return the two halves of the last axis as (k, v)
+ k, v = mx.split(x @ self.weight.T, 2, axis=-1)  # 2 equal sections along the last axis: first -> k, second -> v
+ return k, v
+
+
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@@ -49,8 +65,8 @@ class Attention(nn.Module):
self.wq = nn.Linear(
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
)
- self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
- self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
+ self.wkwv = KVLinear(args.hidden_size, int(self.head_dim) * 2)
+
self.wo = nn.Linear(
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
)
@@ -66,7 +82,8 @@ class Attention(nn.Module):
) -> mx.array:
B, L, _ = x.shape
- q, k, v = self.wq(x), self.wk(x), self.wv(x)
+ q = self.wq(x)
+ k, v = self.wkwv(x)
q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
0, 2, 1, 3
@@ -189,7 +206,9 @@ def load_model(model_path: str):
weights = mx.load(str(model_path / "weights.npz"))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
- model.update(tree_unflatten(list(weights.items())))
+ parameteres = tree_unflatten(list(weights.items()))
+
+ model.update(parameteres)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return model, tokenizer
@@ -206,7 +225,7 @@ if __name__ == "__main__":
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
- default="The winter in Beijing is",
+ default="The winter in Beijing is ",
)
parser.add_argument(
"--max-tokens",