mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
chore: add workaround for quant
This commit is contained in:
parent
b763ad3829
commit
d81dcad68f
@ -84,6 +84,23 @@ def convert(args):
|
||||
state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
|
||||
state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
|
||||
|
||||
# quantization workaround for kv proj
|
||||
keys_to_delete =[]
|
||||
new_state_dict ={}
|
||||
for k, v in state_dict.items():
|
||||
if k.endswith('wk.weight'):
|
||||
prefix = k[:-len('wk.weight')]
|
||||
wv = prefix +"wv.weight"
|
||||
if wv in state_dict:
|
||||
wkwv = torch.cat([v, state_dict[wv]],dim=0)
|
||||
new_key = prefix + "wkwv.weight"
|
||||
new_state_dict[new_key] = wkwv
|
||||
keys_to_delete.extend([k,wv])
|
||||
|
||||
for key in keys_to_delete:
|
||||
del state_dict[key]
|
||||
|
||||
state_dict.update(new_state_dict)
|
||||
weights = {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
keep_keys = set(
|
||||
@ -152,3 +169,4 @@ if __name__ == "__main__":
|
||||
with open(mlx_path / "config.json", "w") as f:
|
||||
config["model_type"] = "yayi"
|
||||
json.dump(config, f, indent=4)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@ -23,6 +24,21 @@ class ModelArgs:
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class KVLinear(nn.Module):
|
||||
def __init__(self, input_dims: int, output_dims: int):
|
||||
super().__init__()
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(output_dims, input_dims),
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
k, v = mx.split(x @ self.weight.T, 2, axis=-1)
|
||||
return k, v
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
@ -49,8 +65,8 @@ class Attention(nn.Module):
|
||||
self.wq = nn.Linear(
|
||||
args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
|
||||
self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
|
||||
self.wkwv = KVLinear(args.hidden_size, int(self.head_dim) * 2)
|
||||
|
||||
self.wo = nn.Linear(
|
||||
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
|
||||
)
|
||||
@ -66,7 +82,8 @@ class Attention(nn.Module):
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
q, k, v = self.wq(x), self.wk(x), self.wv(x)
|
||||
q = self.wq(x)
|
||||
k, v = self.wkwv(x)
|
||||
|
||||
q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
@ -189,7 +206,9 @@ def load_model(model_path: str):
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
parameteres = tree_unflatten(list(weights.items()))
|
||||
|
||||
model.update(parameteres)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
@ -206,7 +225,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="The winter in Beijing is",
|
||||
default="The winter in Beijing is ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
|
Loading…
Reference in New Issue
Block a user