# Copyright © 2023 Apple Inc.

import argparse
import collections
import glob
from pathlib import Path

import numpy as np
import torch

SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
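# The names above identify weight matrices that the original checkpoints split
# across the consolidated.*.pth files: unshard() below concatenates SHARD_FIRST
# entries along axis 0 and SHARD_SECOND entries along axis 1.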


def shard_key(k):
    # The shard group is identified by the second-to-last component of the
    # weight name, e.g. "wq" for a key ending in "wq.weight".
    keys = k.split(".")
    if len(keys) < 2:
        return None
    return keys[-2]


def unshard(k, v):
    # Weights that were not sharded pass through unchanged; sharded weights
    # arrive as a list of arrays and are concatenated along the proper axis.
    wn = shard_key(k)
    if wn not in SHARD_WEIGHTS:
        return v
    elif wn in SHARD_FIRST:
        axis = 0
    elif wn in SHARD_SECOND:
        axis = 1
    else:
        raise ValueError("Invalid weight name")
    return np.concatenate(v, axis=axis)
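# Illustrative example (the key below is hypothetical): shard_key of
# "layers.0.attention.wq.weight" is "wq", which is in SHARD_FIRST, so the
# list of per-file arrays for that key is concatenated along axis 0.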


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument(
        "--model_path",
        help="Path to the Torch model. The MLX weights will also be saved there.",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
    weights = collections.defaultdict(list)
    for wf in torch_files:
        state = torch.load(wf, map_location=torch.device("cpu"))
        for k, v in state.items():
            # Convert every tensor to float16 NumPy; collect sharded weights
            # into lists so they can be recombined once all files are read.
            v = v.to(torch.float16).numpy()
            if shard_key(k) in SHARD_WEIGHTS:
                weights[k].append(v)
            else:
                weights[k] = v

    out_file = str(model_path / "weights.npz")
    for k, v in weights.items():
        weights[k] = unshard(k, v)
    np.savez(out_file, **weights)
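# Example invocation (assuming this file is saved as convert.py; the path is a
# placeholder). The resulting weights.npz is written into the same directory
# as the consolidated.*.pth shards:
#
#   python convert.py --model_path /path/to/llama/torch/weights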