diff --git a/llms/yayi2/convert.py b/llms/yayi2/convert.py
index 5beca05a..51d14634 100644
--- a/llms/yayi2/convert.py
+++ b/llms/yayi2/convert.py
@@ -84,6 +84,24 @@ def convert(args):
     state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
     state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
 
+    # Quantization workaround: fuse the k/v projection weights into a single
+    # wkwv weight, matching the KVLinear module in yayi.py.
+    keys_to_delete = []
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        if k.endswith("wk.weight"):
+            prefix = k[: -len("wk.weight")]
+            wv = prefix + "wv.weight"
+            if wv in state_dict:
+                wkwv = torch.cat([v, state_dict[wv]], dim=0)
+                new_key = prefix + "wkwv.weight"
+                new_state_dict[new_key] = wkwv
+                keys_to_delete.extend([k, wv])
+
+    for key in keys_to_delete:
+        del state_dict[key]
+
+    state_dict.update(new_state_dict)
     weights = {k: v.numpy() for k, v in state_dict.items()}
 
     keep_keys = set(
diff --git a/llms/yayi2/yayi.py b/llms/yayi2/yayi.py
index 29dfd4cf..0889c83a 100644
--- a/llms/yayi2/yayi.py
+++ b/llms/yayi2/yayi.py
@@ -1,6 +1,7 @@
 import argparse
 import json
+import math
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple
 
@@ -23,6 +24,23 @@ class ModelArgs:
     rope_traditional: bool = False
 
 
+class KVLinear(nn.Module):
+    """Fused k/v projection; a plain nn.Module so quantize_module skips it."""
+
+    def __init__(self, input_dims: int, output_dims: int):
+        super().__init__()
+        scale = math.sqrt(1 / input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims, input_dims),
+        )
+
+    def __call__(self, x):
+        k, v = mx.split(x @ self.weight.T, 2, axis=-1)
+        return k, v
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
@@ -49,8 +67,8 @@ class Attention(nn.Module):
         self.wq = nn.Linear(
             args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
         )
-        self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
-        self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
+        self.wkwv = KVLinear(args.hidden_size, int(self.head_dim) * 2)
+
         self.wo = nn.Linear(
             args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
         )
@@ -66,7 +84,8 @@
     ) -> mx.array:
         B, L, _ = x.shape
 
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
+        q = self.wq(x)
+        k, v = self.wkwv(x)
 
         q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
             0, 2, 1, 3
@@ -189,7 +208,8 @@ def load_model(model_path: str):
     weights = mx.load(str(model_path / "weights.npz"))
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
-    model.update(tree_unflatten(list(weights.items())))
+    parameters = tree_unflatten(list(weights.items()))
+    model.update(parameters)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     return model, tokenizer
 
@@ -206,7 +226,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prompt",
         help="The message to be processed by the model",
-        default="The winter in Beijing is",
+        default="The winter in Beijing is ",
    )
     parser.add_argument(
         "--max-tokens",
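
Note (not part of the diff): a minimal numpy sketch of the invariant the two changes rely on. convert.py concatenates the wk/wv weights along the output dimension (dim 0), and KVLinear splits the fused projection's output in half along the last axis; the sketch below checks that these are inverses. The sizes and variable names are illustrative, not yayi2's real configuration.

# Sketch (toy sizes): check that the wkwv fusion in convert.py and the
# split in KVLinear.__call__ recover the original k/v projections.
import numpy as np

hidden_size, head_dim = 64, 8  # illustrative, not yayi2's real dimensions
rng = np.random.default_rng(0)

# Weights are stored as (out_dims, in_dims), as in nn.Linear / the torch state dict.
wk = rng.standard_normal((head_dim, hidden_size))
wv = rng.standard_normal((head_dim, hidden_size))
wkwv = np.concatenate([wk, wv], axis=0)  # what convert.py writes out

x = rng.standard_normal((2, 5, hidden_size))  # (batch, seq_len, hidden)
k, v = np.split(x @ wkwv.T, 2, axis=-1)  # what KVLinear.__call__ computes

assert np.allclose(k, x @ wk.T)
assert np.allclose(v, x @ wv.T)
print("fused wkwv reproduces the separate wk/wv projections")

Because KVLinear is a plain nn.Module rather than an nn.Linear, nn.QuantizedLinear.quantize_module leaves it in full precision. That is the point of the workaround: the k/v projections, whose output dimension is a single head_dim, are kept unquantized while every other linear layer is quantized.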