use the same model structure and module names as HF

Author: Awni Hannun
Date: 2024-01-03 08:26:46 -08:00
Parent: 3fefd2e8eb
Commit: 2e2acc4349
2 changed files with 78 additions and 108 deletions

convert.py

@@ -17,70 +17,18 @@ from models import Model, ModelArgs
 from mlx.utils import tree_flatten, tree_map, tree_unflatten

-def convert(hf_path: str, dtype: str):
-    # Download model, config and tokenizer from HF
+def fetch_from_hub(hf_path: str, dtype: str):
     model = transformers.AutoModelForCausalLM.from_pretrained(
         hf_path,
-        trust_remote_code=True,
         torch_dtype=getattr(torch, dtype),
     ).state_dict()
     config = transformers.AutoConfig.from_pretrained(hf_path)
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         hf_path,
-        trust_remote_code=True,
     )

-    # things to change
-    # 1. there's no "model." in the weight names
-    model = {k.replace("model.", ""): v for k, v in model.items()}
-    # 2. mlp is called feed_forward
-    model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
-    # 3. up_proj, down_proj, gate_proj
-    model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
-    model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
-    model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
-    # 4. layernorms
-    model = {
-        k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
-    }
-    model = {
-        k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
-    }
-    # 5. lm head
-    model = {k.replace("lm_head", "output"): v for k, v in model.items()}
-    # 6. token emb
-    model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
-    # 7. attention
-    model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
-    model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
-    model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
-    model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
-    model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
-
-    params = {}
-    params["model_type"] = "llama"
-    params["dim"] = config.hidden_size
-    params["hidden_dim"] = config.intermediate_size
-    params["head_dim"] = config.hidden_size // config.num_attention_heads
-    params["n_heads"] = config.num_attention_heads
-    if hasattr(config, "num_key_value_heads"):
-        params["n_kv_heads"] = config.num_key_value_heads
-    params["n_layers"] = config.num_hidden_layers
-    params["vocab_size"] = config.vocab_size
-    params["norm_eps"] = config.rms_norm_eps
-    params["rope_traditional"] = False
-    params["rope_theta"] = getattr(config, "rope_theta", 10000)
-
     for k, v in model.items():
         model[k] = mx.array(v.numpy())
-    return model, params, tokenizer
+    return model, config.to_dict(), tokenizer


 def quantize(weights, config, args):
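The effect of dropping the rename pass is easiest to see on a single weight name: under the old convert(), an HF key went through a chain of substitutions before the MLX modules would accept it; with matching module names it is used verbatim. A standalone sketch (the key and the rename list below are illustrative, copied from the deleted code above):

hf_key = "model.layers.0.self_attn.q_proj.weight"

old_renames = [
    ("model.", ""),
    ("mlp", "feed_forward"),
    ("down_proj", "w2"),
    ("up_proj", "w3"),
    ("gate_proj", "w1"),
    ("self_attn", "attention"),
    ("q_proj", "wq"),
]
old_key = hf_key
for src, dst in old_renames:
    old_key = old_key.replace(src, dst)

print(old_key)  # layers.0.attention.wq.weight -- what the old MLX modules expected
print(hf_key)   # model.layers.0.self_attn.q_proj.weight -- what the new modules expect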
@@ -89,8 +37,7 @@ def quantize(weights, config, args):
     # Load the model:
     model = Model(ModelArgs(**config))
     weights = tree_map(mx.array, weights)
-    # TODO replace with model.load_weights
-    model.update(tree_unflatten(list(weights.items())))
+    model.load_weights(list(weights.items()))

     # Quantize the model:
     nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
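Module.load_weights subsumes the manual unflatten-and-update it replaces here, and also checks that the provided keys match the module tree. A minimal sketch of the two styles (Tiny is a made-up module, not part of this repo):

import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

# Flat {name: array} dict, e.g. {"proj.weight": ..., "proj.bias": ...}
weights = dict(tree_flatten(Tiny().parameters()))

m = Tiny()
# Old style: rebuild the nested tree, then merge it in (no key checking).
m.update(tree_unflatten(list(weights.items())))
# New style: one call, with key validation.
m.load_weights(list(weights.items()))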
@@ -163,7 +110,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     print("[INFO] Loading")
-    weights, config, tokenizer = convert(args.hf_path, args.dtype)
+    weights, config, tokenizer = fetch_from_hub(args.hf_path, args.dtype)
     if args.quantize:
         print("[INFO] Quantizing")
         weights, config = quantize(weights, config, args)
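Since fetch_from_hub returns the raw config.to_dict() instead of a hand-built params dict, downstream code sees HF field names everywhere. An interactive sketch (assumes this script imports as convert; the repo id is only an example and fetching it downloads the full weights):

from convert import fetch_from_hub

weights, config, tokenizer = fetch_from_hub("meta-llama/Llama-2-7b-hf", "float16")
print(config["hidden_size"], config["num_attention_heads"])  # HF names, unmodified
print(next(iter(weights)))  # e.g. "model.embed_tokens.weight"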

models.py

@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 import glob
+import inspect
 import json
 from pathlib import Path
 from typing import Optional, Tuple
@@ -15,17 +16,27 @@ from huggingface_hub import snapshot_download
 @dataclass
 class ModelArgs:
-    dim: int
-    n_layers: int
-    head_dim: int
-    hidden_dim: int
-    n_heads: int
-    n_kv_heads: int
-    norm_eps: float
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
     vocab_size: int
+    num_key_value_heads: int = None
     rope_theta: float = 10000
-    rope_traditional: bool = True
+    rope_traditional: bool = False
+    model_type: str = None
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}
+        )


 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
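from_dict lets a raw HF config.json construct ModelArgs without first stripping unknown keys: inspect.signature(cls).parameters lists exactly the dataclass fields, so everything else is ignored. A sketch (the config values are illustrative, roughly Llama-2-7B sized):

from models import ModelArgs

config = {
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "intermediate_size": 11008,
    "num_attention_heads": 32,
    "rms_norm_eps": 1e-05,
    "vocab_size": 32000,
    "architectures": ["LlamaForCausalLM"],  # not a ModelArgs field: silently dropped
}
args = ModelArgs.from_dict(config)
print(args.num_key_value_heads)  # 32: __post_init__ falls back to num_attention_heads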
@@ -44,20 +55,22 @@ class RMSNorm(nn.Module):
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.args = args
-
-        self.n_heads: int = args.n_heads
-        self.n_kv_heads: int = args.n_kv_heads
-
-        self.repeats = self.n_heads // self.n_kv_heads
-
-        self.scale = self.args.head_dim**-0.5
-
-        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=args.rope_traditional, base=args.rope_theta)
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        self.repeats = n_heads // n_kv_heads
+
+        head_dim = args.hidden_size // n_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+        self.rope = nn.RoPE(head_dim, traditional=args.rope_traditional, base=args.rope_theta)

     def __call__(
         self,
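head_dim is now derived as hidden_size // num_attention_heads rather than read from a separate config field, which pins down all four projection shapes. Worked numbers (Mistral-7B-like values, used only for illustration):

hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8                                # grouped-query attention

head_dim = hidden_size // num_attention_heads          # 128
repeats = num_attention_heads // num_key_value_heads   # 4 query heads per kv head

print(num_attention_heads * head_dim)   # 4096: q_proj and o_proj map 4096 -> 4096
print(num_key_value_heads * head_dim)   # 1024: k_proj and v_proj map 4096 -> 1024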
@@ -67,7 +80,7 @@ class Attention(nn.Module):
     ) -> mx.array:
         B, L, D = x.shape

-        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

         # Prepare the queries, keys and values for the attention computation
         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
@@ -78,7 +91,8 @@ class Attention(nn.Module):
             a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
             return a.reshape([B, self.n_heads, L, -1])

-        keys, values = map(repeat, (keys, values))
+        if self.repeats > 1:
+            keys, values = map(repeat, (keys, values))

         if cache is not None:
             key_cache, value_cache = cache
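The new guard skips the tiling entirely when n_heads == n_kv_heads (plain multi-head attention); when repeats > 1, each kv head is duplicated along the head axis so every query head has a partner. A standalone shape check of that repeat, with toy sizes:

import mlx.core as mx

B, n_kv_heads, L, head_dim, repeats = 1, 2, 3, 4, 2
keys = mx.zeros((B, n_kv_heads, L, head_dim))

a = mx.concatenate([mx.expand_dims(keys, 2)] * repeats, axis=2)
tiled = a.reshape([B, n_kv_heads * repeats, L, head_dim])
print(tiled.shape)  # 1 x 4 x 3 x 4: two kv heads tiled up to four attention heads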
@@ -95,30 +109,29 @@ class Attention(nn.Module):
             scores += mask
         scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
         output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.wo(output), (keys, values)
+        return self.o_proj(output), (keys, values)


-class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
         super().__init__()
-
-        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
-        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
-        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

     def __call__(self, x) -> mx.array:
-        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
-        self.feed_forward = FeedForward(args=args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.args = args

     def __call__(
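The rename from FeedForward/w1/w2/w3 to MLP/gate_proj/down_proj/up_proj changes names only; the computation is still the Llama SwiGLU block, down(silu(gate(x)) * up(x)). A quick shape check of the renamed module (toy dims, assuming the MLP class above is importable from models):

import mlx.core as mx
from models import MLP

mlp = MLP(dim=8, hidden_dim=16)
x = mx.random.normal((1, 3, 8))
print(mlp(x).shape)  # batch 1, length 3, back to dim 8 after the 16-wide gate/up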
@@ -127,31 +140,30 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.attention(self.attention_norm(x), mask, cache)
+        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
-        r = self.feed_forward(self.ffn_norm(h))
+        r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
         return out, cache


-class Model(nn.Module):
+class LlamaModel(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
         self.vocab_size = args.vocab_size
-        self.n_layers = args.n_layers
+        self.num_hidden_layers = args.num_hidden_layers
         assert self.vocab_size > 0
-        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
-        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [TransformerBlock(args=args) for _ in range(args.num_hidden_layers)]
+        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

     def __call__(
         self,
         inputs: mx.array,
         cache=None,
     ):
-        h = self.tok_embeddings(inputs)
+        h = self.embed_tokens(inputs)

         mask = None
         if h.shape[1] > 1:
@@ -164,7 +176,22 @@ class Model(nn.Module):
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, mask, cache[e])

-        return self.output(self.norm(h)), cache
+        return self.norm(h), cache
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.model = LlamaModel(args)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out, cache = self.model(inputs, cache)
+        return self.lm_head(out), cache


 def load(path_or_hf_repo: str):
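Splitting the old Model into LlamaModel plus a thin Model wrapper is what restores the HF parameter paths: everything in the trunk picks up the model. prefix, and the output head lives at top level as lm_head, so checkpoints load without renaming. A sketch that prints the resulting paths (tiny made-up config values):

from mlx.utils import tree_flatten
from models import Model, ModelArgs

args = ModelArgs(
    hidden_size=8,
    num_hidden_layers=1,
    intermediate_size=16,
    num_attention_heads=2,
    rms_norm_eps=1e-5,
    vocab_size=32,
)
for name, _ in tree_flatten(Model(args).parameters()):
    print(name)  # model.embed_tokens.weight, model.layers.0.self_attn.q_proj.weight,
                 # ..., model.norm.weight, lm_head.weight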
@@ -178,9 +205,8 @@ def load(path_or_hf_repo: str):
     with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
-    config.pop("model_type", None)
-    quantization = config.pop("quantization", None)
-    model_args = ModelArgs(**config)
+    quantization = config.get("quantization", None)
+    model_args = ModelArgs.from_dict(config)

     weight_files = glob.glob(str(model_path / "weights.*.safetensors"))
     if len(weight_files) == 0:
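With from_dict doing the filtering, load() no longer needs to pop() keys that ModelArgs cannot accept: the quantization block stays in the dict for later use, and model_type is now a proper field. An illustration (the group_size/bits values are examples of what a quantized config might carry):

from models import ModelArgs

config = {
    "model_type": "llama",
    "hidden_size": 8,
    "num_hidden_layers": 1,
    "intermediate_size": 16,
    "num_attention_heads": 2,
    "rms_norm_eps": 1e-5,
    "vocab_size": 32,
    "quantization": {"group_size": 64, "bits": 4},
}
quantization = config.get("quantization", None)   # read, not popped
model_args = ModelArgs.from_dict(config)          # "quantization" is filtered out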
@@ -194,14 +220,11 @@ def load(path_or_hf_repo: str):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)

-    # TODO replace with model.load_weights(list(weights.items()))
-    # model.load_weights(weights)
-    model.update(tree_unflatten(list(weights.items())))
+    model.load_weights(list(weights.items()))

     mx.eval(model.parameters())

     tokenizer = AutoTokenizer.from_pretrained(
         model_path,
-        trust_remote_code=True,
     )
     return model, tokenizer
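End to end, the new load() ties these pieces together. A usage sketch ("mlx_model" is a placeholder for a directory produced by the conversion script, with config.json and weights.*.safetensors inside):

import mlx.core as mx
from models import load

model, tokenizer = load("mlx_model")

tokens = tokenizer("hello", return_tensors="np")["input_ids"]
logits, cache = model(mx.array(tokens))
print(logits.shape)  # (1, prompt_length, vocab_size)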