llama v2 with sharded weights

This commit is contained in:
Awni Hannun
2023-12-12 12:48:15 -08:00
parent 03a408fa2e
commit c18b990f08
5 changed files with 189 additions and 123 deletions

View File

@@ -1,53 +1,59 @@
# Copyright © 2023 Apple Inc.
import argparse
from itertools import starmap
import collections
import glob
from pathlib import Path
import numpy as np
import torch
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
def shard_key(k):
keys = k.split(".")
if len(keys) < 2:
return None
return keys[-2]
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return (
key,
value.numpy()
if value.dtype != torch.bfloat16
else value.to(torch.float32).numpy(),
)
def unshard(k, v):
wn = shard_key(k)
if wn not in SHARD_WEIGHTS:
return v
elif wn in SHARD_FIRST:
axis = 0
elif wn in SHARD_SECOND:
axis = 1
else:
raise ValueError("Invalid weight name")
return np.concatenate(v, axis=axis)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
parser.add_argument(
"--model_path",
help="Path to the Torch model. The MLX weights will also be saved there.",
)
args = parser.parse_args()
state = torch.load(args.torch_weights, map_location=torch.device('cpu'))
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
)
model_path = Path(args.model_path)
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
weights = collections.defaultdict(list)
for wf in torch_files:
state = torch.load(wf, map_location=torch.device("cpu"))
for k, v in state.items():
v = v.to(torch.float16).numpy()
if shard_key(k) in SHARD_WEIGHTS:
weights[k].append(v)
else:
weights[k] = v
out_file = str(model_path / "weights.npz")
for k, v in weights.items():
weights[k] = unshard(k, v)
np.savez(out_file, **weights)