mirror of https://github.com/ml-explore/mlx-examples.git
llama v2 with sharded weights
@@ -1,53 +1,59 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
-from itertools import starmap
+import collections
+import glob
+from pathlib import Path
 
 import numpy as np
 import torch
 
+SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
 
-def map_torch_to_mlx(key, value):
-    if "tok_embedding" in key:
-        key = "embedding.weight"
-
-    elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
+def shard_key(k):
+    keys = k.split(".")
+    if len(keys) < 2:
+        return None
+    return keys[-2]
 
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||
key = key.replace("wq", "query_proj")
|
||||
key = key.replace("wk", "key_proj")
|
||||
key = key.replace("wv", "value_proj")
|
||||
key = key.replace("wo", "out_proj")
|
||||
|
||||
elif "w1" in key or "w2" in key or "w3" in key:
|
||||
# The FFN is a separate submodule in PyTorch
|
||||
key = key.replace("feed_forward.w1", "linear1")
|
||||
key = key.replace("feed_forward.w3", "linear2")
|
||||
key = key.replace("feed_forward.w2", "linear3")
|
||||
|
||||
elif "output" in key:
|
||||
key = key.replace("output", "out_proj")
|
||||
|
||||
elif "rope" in key:
|
||||
return None, None
|
||||
|
||||
return (
|
||||
key,
|
||||
value.numpy()
|
||||
if value.dtype != torch.bfloat16
|
||||
else value.to(torch.float32).numpy(),
|
||||
)
|
||||
+def unshard(k, v):
+    wn = shard_key(k)
+    if wn not in SHARD_WEIGHTS:
+        return v
+    elif wn in SHARD_FIRST:
+        axis = 0
+    elif wn in SHARD_SECOND:
+        axis = 1
+    else:
+        raise ValueError("Invalid weight name")
+    return np.concatenate(v, axis=axis)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument("torch_weights")
-    parser.add_argument("output_file")
+    parser.add_argument(
+        "--model_path",
+        help="Path to the Torch model. The MLX weights will also be saved there.",
+    )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_weights, map_location=torch.device("cpu"))
-    np.savez(
-        args.output_file,
-        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
-    )
+    model_path = Path(args.model_path)
+    torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
+    weights = collections.defaultdict(list)
+    for wf in torch_files:
+        state = torch.load(wf, map_location=torch.device("cpu"))
+        for k, v in state.items():
+            v = v.to(torch.float16).numpy()
+            if shard_key(k) in SHARD_WEIGHTS:
+                weights[k].append(v)
+            else:
+                weights[k] = v
+
+    out_file = str(model_path / "weights.npz")
+    for k, v in weights.items():
+        weights[k] = unshard(k, v)
+    np.savez(out_file, **weights)
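
After this change the converter no longer takes a single checkpoint file and an output name: it takes the directory that holds the Torch model, picks up every consolidated.*.pth shard inside it, and writes weights.npz back into the same directory. A hypothetical invocation (script and directory names are illustrative):

    python convert.py --model_path Llama-2-13b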
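
To make the shard-axis convention concrete: shard_key returns the second-to-last dotted component of a weight name, so shard_key("layers.0.attention.wq.weight") is "wq". Names listed in SHARD_FIRST are recombined along axis 0 and those in SHARD_SECOND along axis 1, matching how the tensor-parallel checkpoints split column-parallel and row-parallel matrices across files. Below is a minimal standalone sketch of that round trip on synthetic NumPy arrays; the shapes and shard count are made up for illustration and are not part of the commit.

    import numpy as np

    # Hypothetical full matrices (shapes chosen only for illustration).
    n_shards = 2
    wq_full = np.arange(4 * 6, dtype=np.float16).reshape(4, 6)  # behaves like a SHARD_FIRST weight
    wo_full = np.arange(6 * 4, dtype=np.float16).reshape(6, 4)  # behaves like a SHARD_SECOND weight

    # Simulate the per-file pieces that each consolidated.*.pth would hold.
    wq_pieces = np.split(wq_full, n_shards, axis=0)  # SHARD_FIRST: split across rows
    wo_pieces = np.split(wo_full, n_shards, axis=1)  # SHARD_SECOND: split across columns

    # unshard concatenates the collected pieces back along the matching axis.
    assert np.array_equal(np.concatenate(wq_pieces, axis=0), wq_full)
    assert np.array_equal(np.concatenate(wo_pieces, axis=1), wo_full)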