mirror of https://github.com/ml-explore/mlx-examples.git
shard llama model after conversion and unshard on loading (#174)
@@ -15,7 +15,6 @@ import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
-
 def llama(model_path):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
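These two lists classify the Meta checkpoint tensors by the axis along which the original consolidated.*.pth files split them: column-parallel weights (wq, wk, wv, w1, w3, output) are rejoined along the first axis, row-parallel ones (tok_embeddings, wo, w2) along the second. A minimal sketch of that merge step, with a hypothetical helper name and NumPy standing in for torch tensors; the axis convention is inferred from the list names and Meta's model-parallel layout, not quoted from this diff:

    import numpy as np

    def concat_torch_shards(shard_dicts, shard_first, shard_second):
        # Hypothetical helper (not in the repo): merge the per-file tensors
        # into one array per key, mirroring what convert.py's llama() does
        # with the torch checkpoint shards.
        merged = {}
        for key in shard_dicts[0]:
            parts = [d[key] for d in shard_dicts]
            if any(key.endswith(name + ".weight") for name in shard_first):
                merged[key] = np.concatenate(parts, axis=0)  # split on first axis
            elif any(key.endswith(name + ".weight") for name in shard_second):
                merged[key] = np.concatenate(parts, axis=1)  # split on second axis
            else:
                merged[key] = parts[0]  # assumed replicated (e.g. norm weights)
        return merged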
@@ -140,6 +139,22 @@ def quantize(weights, config, args):
     return quantized_weights, quantized_config
 
 
+def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
+    max_file_size_bytes = max_file_size_gibibyte << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        # TODO: simplify to v.nbytes as soon as mx.array exposes it
+        estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
+        if shard_size + estimated_size > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += estimated_size
+    shards.append(shard)
+    return shards
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
     parser.add_argument(
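make_shards packs tensors greedily: it accumulates entries until adding the next one would push the running byte count past the cap, then starts a new shard. The default cap is 15 << 30 = 15 * 2^30 = 16,106,127,360 bytes (15 GiB). Note the size estimate: plain NumPy arrays expose .nbytes directly, while mx.array sizes are computed as size * dtype.size. A quick check of the packing behavior, using a stand-in object so no multi-gigabyte arrays are actually allocated (FakeTensor is illustrative, not from the repo, and assumes make_shards above is in scope):

    class FakeTensor:
        """Stand-in exposing only .nbytes, i.e. the non-mx.array branch above."""
        def __init__(self, nbytes):
            self.nbytes = nbytes

    weights = {k: FakeTensor(10 << 30) for k in ("a", "b", "c")}  # 10 GiB each
    shards = make_shards(weights)  # 15 GiB cap: no two 10 GiB tensors fit together
    print([list(s) for s in shards])  # [['a'], ['b'], ['c']]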
@@ -200,6 +215,11 @@ if __name__ == "__main__":
         str(torch_path / "tokenizer.model"),
         str(mlx_path / "tokenizer.model"),
     )
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    shards = make_shards(weights)
+    if len(shards) == 1:
+        np.savez(str(mlx_path / f"weights.npz"), **shards[0])
+    else:
+        for i, shard in enumerate(shards):
+            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
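The "unshard on loading" half of the commit lives in llama/llama.py and is not shown in this excerpt. Since np.savez writes disjoint key sets per file, the loader only has to merge the dicts back together; a minimal sketch under that assumption, with an illustrative function name, relying on the weights.NN.npz naming written above (mx.load reads an .npz archive into a dict of arrays):

    import glob
    from pathlib import Path

    import mlx.core as mx

    def load_weights(model_path):
        # Collect weights.npz or weights.00.npz, weights.01.npz, ...
        # Keys are disjoint across shards, so dict.update() reassembles the
        # full parameter set regardless of how many files were written.
        weights = {}
        for shard_file in sorted(glob.glob(str(Path(model_path) / "weights*.npz"))):
            weights.update(mx.load(shard_file))
        return weights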