Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
chore: add workaround for quant
@@ -84,6 +84,23 @@ def convert(args):
     state_dict = {k.replace("v_proj", "wv"): v for k, v in state_dict.items()}
     state_dict = {k.replace("o_proj", "wo"): v for k, v in state_dict.items()}
 
+    # quantization workaround: fuse each wk/wv pair into a single wkwv weight
+    keys_to_delete = []
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        if k.endswith("wk.weight"):
+            prefix = k[: -len("wk.weight")]
+            wv = prefix + "wv.weight"
+            if wv in state_dict:
+                wkwv = torch.cat([v, state_dict[wv]], dim=0)
+                new_key = prefix + "wkwv.weight"
+                new_state_dict[new_key] = wkwv
+                keys_to_delete.extend([k, wv])
+
+    for key in keys_to_delete:
+        del state_dict[key]
+
+    state_dict.update(new_state_dict)
     weights = {k: v.numpy() for k, v in state_dict.items()}
 
     keep_keys = set(
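Run on a toy state dict, the fusion block above behaves like this (the key names here are hypothetical stand-ins, patterned after the q/k/v renames earlier in convert()):

import torch

state_dict = {
    "layers.0.attention.wq.weight": torch.zeros(16, 16),
    "layers.0.attention.wk.weight": torch.zeros(4, 16),
    "layers.0.attention.wv.weight": torch.ones(4, 16),
}

# same loop as the hunk above
keys_to_delete = []
new_state_dict = {}
for k, v in state_dict.items():
    if k.endswith("wk.weight"):
        prefix = k[: -len("wk.weight")]
        wv = prefix + "wv.weight"
        if wv in state_dict:
            new_state_dict[prefix + "wkwv.weight"] = torch.cat([v, state_dict[wv]], dim=0)
            keys_to_delete.extend([k, wv])
for key in keys_to_delete:
    del state_dict[key]
state_dict.update(new_state_dict)

print(sorted(state_dict))
# ['layers.0.attention.wkwv.weight', 'layers.0.attention.wq.weight']
print(state_dict["layers.0.attention.wkwv.weight"].shape)  # torch.Size([8, 16])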
@@ -152,3 +169,4 @@ if __name__ == "__main__":
     with open(mlx_path / "config.json", "w") as f:
         config["model_type"] = "yayi"
         json.dump(config, f, indent=4)
+
@@ -1,6 +1,7 @@
 import argparse
 import json
 from dataclasses import dataclass
+import math
 from pathlib import Path
 from typing import Optional, Tuple
 
@@ -23,6 +24,21 @@ class ModelArgs:
     rope_traditional: bool = False
 
 
+class KVLinear(nn.Module):
+    def __init__(self, input_dims: int, output_dims: int):
+        super().__init__()
+        scale = math.sqrt(1 / input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims, input_dims),
+        )
+
+    def __call__(self, x):
+        k, v = mx.split(x @ self.weight.T, 2, axis=-1)
+        return k, v
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
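KVLinear is deliberately a plain nn.Module rather than an nn.Linear: nn.QuantizedLinear.quantize_module (called in load_model below) replaces nn.Linear layers, so the fused kv weight is skipped and stays in full precision. A minimal sketch of that behavior (Toy is a hypothetical module; the quantize_module kwargs assume the mlx version this example targets):

import math

import mlx.core as mx
import mlx.nn as nn


class KVLinear(nn.Module):  # same definition as in the diff
    def __init__(self, input_dims: int, output_dims: int):
        super().__init__()
        scale = math.sqrt(1 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale, high=scale, shape=(output_dims, input_dims)
        )

    def __call__(self, x):
        k, v = mx.split(x @ self.weight.T, 2, axis=-1)
        return k, v


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(64, 64, bias=False)  # nn.Linear: gets quantized
        self.wkwv = KVLinear(64, 8)              # plain Module: left alone


model = Toy()
nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)
print(type(model.wq).__name__)    # QuantizedLinear
print(type(model.wkwv).__name__)  # KVLinear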
@@ -49,8 +65,8 @@ class Attention(nn.Module):
         self.wq = nn.Linear(
             args.hidden_size, args.num_attention_heads * self.head_dim, bias=False
         )
-        self.wk = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
-        self.wv = nn.Linear(args.hidden_size, int(self.head_dim), bias=False)
+        self.wkwv = KVLinear(args.hidden_size, int(self.head_dim) * 2)
+
         self.wo = nn.Linear(
             args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
         )
@@ -66,7 +82,8 @@ class Attention(nn.Module):
     ) -> mx.array:
         B, L, _ = x.shape
 
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
+        q = self.wq(x)
+        k, v = self.wkwv(x)
 
         q = q.reshape(B, L, self.num_attention_heads, self.head_dim).transpose(
             0, 2, 1, 3
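The new call path is numerically the old one, provided the weights were fused the way convert() fuses them (wk rows first, then wv). A quick check with stand-in shapes:

import mlx.core as mx

hidden, head_dim = 16, 4
x = mx.random.normal((2, 3, hidden))         # (B, L, hidden)
wk = mx.random.normal((head_dim, hidden))
wv = mx.random.normal((head_dim, hidden))
wkwv = mx.concatenate([wk, wv], axis=0)      # mirrors torch.cat(..., dim=0)

k_old, v_old = x @ wk.T, x @ wv.T                # the removed wk/wv projections
k_new, v_new = mx.split(x @ wkwv.T, 2, axis=-1)  # what self.wkwv(x) computes

print(mx.abs(k_new - k_old).max())  # ~0
print(mx.abs(v_new - v_old).max())  # ~0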
@@ -189,7 +206,9 @@ def load_model(model_path: str):
     weights = mx.load(str(model_path / "weights.npz"))
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
-    model.update(tree_unflatten(list(weights.items())))
+    parameters = tree_unflatten(list(weights.items()))
+
+    model.update(parameters)
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     return model, tokenizer
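Splitting model.update into an explicit tree_unflatten step doesn't change behavior, it only names the intermediate; the ordering still matters, since quantization has to happen before the update so the quantized scales and biases in the npz have matching QuantizedLinear parameters to land in. For intuition, a sketch of what tree_unflatten (from mlx.utils) does with the flat npz keys, with dummy values standing in for arrays:

from mlx.utils import tree_unflatten

weights = {
    "layers.0.attention.wq.weight": "wq-array",
    "layers.0.attention.wkwv.weight": "wkwv-array",
}
parameters = tree_unflatten(list(weights.items()))
print(parameters)
# {'layers': [{'attention': {'wq': {'weight': 'wq-array'},
#                            'wkwv': {'weight': 'wkwv-array'}}}]}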
@@ -206,7 +225,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prompt",
         help="The message to be processed by the model",
-        default="The winter in Beijing is",
+        default="The winter in Beijing is ",
     )
     parser.add_argument(
         "--max-tokens",