Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00.
Commit: "shard llama model after conversion and unshard on loading (#174)".
The diff below is from this commit.
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import glob
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
@@ -330,7 +331,23 @@ def sanitize_config(config, weights):
|
||||
|
||||
def load_model(model_path):
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
|
||||
unsharded_weights_path = Path(model_path / "weights.npz")
|
||||
if unsharded_weights_path.is_file():
|
||||
print("[INFO] Loading model from {}.".format(unsharded_weights_path))
|
||||
weights = mx.load(str(unsharded_weights_path))
|
||||
else:
|
||||
sharded_weights_glob = str(model_path / "weights.*.npz")
|
||||
weight_files = glob.glob(sharded_weights_glob)
|
||||
print("[INFO] Loading model from {}.".format(sharded_weights_glob))
|
||||
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No weights found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = sanitize_config(json.loads(f.read()), weights)
|
||||
quantization = config.pop("quantization", None)
|
||||
@@ -373,7 +390,6 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
if args.few_shot:
|
||||
few_shot_generate(args)
|
||||
|
||||
(End of diff. Page footer links: "Reference in New Issue" · "Block a user".)