shard llama model after conversion and unshard on loading (#174)

This commit is contained in:
Daniel Strobusch
2023-12-25 20:19:43 +01:00
committed by GitHub
parent 738448c2d4
commit 2bd20ef0e0
2 changed files with 40 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
import argparse
import json
import time
import glob
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
@@ -330,7 +331,23 @@ def sanitize_config(config, weights):
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
unsharded_weights_path = Path(model_path / "weights.npz")
if unsharded_weights_path.is_file():
print("[INFO] Loading model from {}.".format(unsharded_weights_path))
weights = mx.load(str(unsharded_weights_path))
else:
sharded_weights_glob = str(model_path / "weights.*.npz")
weight_files = glob.glob(sharded_weights_glob)
print("[INFO] Loading model from {}.".format(sharded_weights_glob))
if len(weight_files) == 0:
raise FileNotFoundError("No weights found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
with open(model_path / "config.json", "r") as f:
config = sanitize_config(json.loads(f.read()), weights)
quantization = config.pop("quantization", None)
@@ -373,7 +390,6 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
print("[INFO] Loading model from disk.")
model, tokenizer = load_model(args.model_path)
if args.few_shot:
few_shot_generate(args)