shard llama model after conversion and unshard on loading (#174)

Daniel Strobusch 2023-12-25 20:19:43 +01:00 committed by GitHub
parent 738448c2d4
commit 2bd20ef0e0
2 changed files with 40 additions and 4 deletions

llama/convert.py

@@ -15,7 +15,6 @@ import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
-

 def llama(model_path):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
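
For context, SHARD_FIRST and SHARD_SECOND name the weights that the original
multi-file torch checkpoints split along the first and second axis,
respectively, so unsharding concatenates the per-file pieces back along that
axis. A rough sketch of the idea, with hypothetical shapes:

import numpy as np

# Two checkpoint files each hold one piece of a weight (hypothetical shapes).
parts = [np.zeros((2, 4)), np.ones((2, 4))]

# A SHARD_FIRST weight (e.g. "wq") is split along axis 0 -> (4, 4).
wq = np.concatenate(parts, axis=0)

# A SHARD_SECOND weight (e.g. "wo") is split along axis 1 -> (2, 8).
wo = np.concatenate(parts, axis=1)

assert wq.shape == (4, 4) and wo.shape == (2, 8)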
@@ -140,6 +139,22 @@ def quantize(weights, config, args):
     return quantized_weights, quantized_config


+def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
+    max_file_size_bytes = max_file_size_gibibyte << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        # TODO: simplify to v.nbytes as soon as mx.array exposes it
+        estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
+        if shard_size + estimated_size > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += estimated_size
+    shards.append(shard)
+    return shards
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
     parser.add_argument(
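
For intuition, here is a self-contained sketch of the same greedy packing
using plain NumPy arrays; the byte-level max_bytes parameter and the
"if shard" guard against an empty leading shard are additions for this
illustration, not part of the function above:

import numpy as np

def make_shards_bytes(weights: dict, max_bytes: int) -> list:
    """Greedily pack arrays into shards of at most max_bytes each."""
    shards, shard, shard_size = [], {}, 0
    for k, v in weights.items():
        # Start a new shard once the cap would be exceeded; the guard avoids
        # emitting an empty shard when a single array exceeds the cap.
        if shard and shard_size + v.nbytes > max_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards

# Three 1 MiB arrays under a 2 MiB cap pack into shards of 2 and 1 arrays.
toy = {f"w{i}": np.zeros(1 << 20, dtype=np.uint8) for i in range(3)}
assert [len(s) for s in make_shards_bytes(toy, 2 << 20)] == [2, 1]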
@@ -200,6 +215,11 @@ if __name__ == "__main__":
         str(torch_path / "tokenizer.model"),
         str(mlx_path / "tokenizer.model"),
     )
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    shards = make_shards(weights)
+    if len(shards) == 1:
+        np.savez(str(mlx_path / "weights.npz"), **shards[0])
+    else:
+        for i, shard in enumerate(shards):
+            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
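
The {i:02d} formatting zero-pads shard indices so the filenames sort in write
order for up to 100 shards, and all of them match the weights.*.npz glob used
by the loader below:

>>> [f"weights.{i:02d}.npz" for i in range(3)]
['weights.00.npz', 'weights.01.npz', 'weights.02.npz']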

llama/llama.py

@@ -3,6 +3,7 @@
 import argparse
 import json
 import time
+import glob
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple
@@ -330,7 +331,23 @@ def sanitize_config(config, weights):

 def load_model(model_path):
     model_path = Path(model_path)
-    weights = mx.load(str(model_path / "weights.npz"))
+
+    unsharded_weights_path = Path(model_path / "weights.npz")
+    if unsharded_weights_path.is_file():
+        print("[INFO] Loading model from {}.".format(unsharded_weights_path))
+        weights = mx.load(str(unsharded_weights_path))
+    else:
+        sharded_weights_glob = str(model_path / "weights.*.npz")
+        weight_files = glob.glob(sharded_weights_glob)
+        print("[INFO] Loading model from {}.".format(sharded_weights_glob))
+
+        if len(weight_files) == 0:
+            raise FileNotFoundError("No weights found in {}".format(model_path))
+
+        weights = {}
+        for wf in weight_files:
+            weights.update(mx.load(wf).items())
+
     with open(model_path / "config.json", "r") as f:
         config = sanitize_config(json.loads(f.read()), weights)
     quantization = config.pop("quantization", None)
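
A quick roundtrip check of the unsharding path, with NumPy standing in for
mx.load (the paths are hypothetical; np.load on an .npz archive returns a
mapping of names to arrays, just as mx.load does):

import glob
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as d:
    model_path = Path(d)
    np.savez(model_path / "weights.00.npz", a=np.arange(3))
    np.savez(model_path / "weights.01.npz", b=np.ones(2))

    # Same merge logic as load_model: glob the shards into one dict.
    weights = {}
    for wf in glob.glob(str(model_path / "weights.*.npz")):
        with np.load(wf) as npz:
            weights.update(npz.items())

    assert sorted(weights) == ["a", "b"]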
@@ -373,7 +390,6 @@ if __name__ == "__main__":
     mx.random.seed(args.seed)

-    print("[INFO] Loading model from disk.")
     model, tokenizer = load_model(args.model_path)

     if args.few_shot:
         few_shot_generate(args)