keep dtype on model conversion (#186)

Daniel Strobusch 2024-01-02 20:20:29 +01:00 committed by GitHub
parent 85258b2be7
commit 1d09c4fecd

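For orientation before the diff: the core of the change is the new torch_to_mx helper, which loads each torch tensor in a requested dtype and converts it to an MLX array. Because numpy has no bfloat16 dtype, bfloat16 tensors take a lossless detour through float32 before .numpy() and are cast back on the MLX side. A minimal sketch of that round trip (the tensor t and variable names are illustrative, not part of the commit):

    import mlx.core as mx
    import torch

    t = torch.randn(4, dtype=torch.bfloat16)
    # t.numpy() would raise a TypeError, since numpy cannot represent bfloat16;
    # upcasting to float32 is exact for bfloat16 values, so nothing is lost.
    a = mx.array(t.to(torch.float32).numpy(), mx.bfloat16)
    print(a.dtype)  # mlx.core.bfloat16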

@@ -10,13 +10,16 @@ from pathlib import Path
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten


+def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
+    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
+    a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
+    return mx.array(a.numpy(), getattr(mx, dtype))
+
+
-def llama(model_path):
+def llama(model_path, *, dtype: str):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
     SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
@@ -37,14 +40,15 @@ def llama(model_path):
             axis = 1
         else:
             raise ValueError("Invalid weight name")
-        return np.concatenate(v, axis=axis)
+        return mx.concatenate(v, axis=axis)

     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
         state = torch.load(wf, map_location=torch.device("cpu"))
         for k, v in state.items():
-            v = v.to(torch.float16).numpy()
+            v = torch_to_mx(v, dtype=dtype)
+            state[k] = None  # free memory
             if shard_key(k) in SHARD_WEIGHTS:
                 weights[k].append(v)
             else:
@@ -57,7 +61,7 @@ def llama(model_path):
     return weights, params


-def tiny_llama(model_path):
+def tiny_llama(model_path, *, dtype: str):
     try:
         import transformers
     except ImportError:
@@ -113,7 +117,7 @@ def tiny_llama(model_path):
     params["vocab_size"] = config.vocab_size
     params["norm_eps"] = config.rms_norm_eps
     params["rope_traditional"] = False
-    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+    weights = {k: torch_to_mx(v, dtype=dtype) for k, v in model.items()}

     return weights, params
@@ -197,6 +201,13 @@ if __name__ == "__main__":
         type=int,
         default=4,
     )
+    parser.add_argument(
+        "--dtype",
+        help="dtype for loading the torch model and input for quantization or saving the converted model. "
+        "The original weights are stored in bfloat16.",
+        type=str,
+        default="float16",
+    )

     args = parser.parse_args()
@@ -205,7 +216,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)

     print("[INFO] Loading")
-    weights, params = globals()[args.model_name](torch_path)
+    weights, params = globals()[args.model_name](torch_path, dtype=args.dtype)
     params["model_type"] = "llama"
     if args.quantize:
         print("[INFO] Quantizing")
@@ -218,9 +229,9 @@ if __name__ == "__main__":
     )
     shards = make_shards(weights)
     if len(shards) == 1:
-        np.savez(str(mlx_path / f"weights.npz"), **shards[0])
+        mx.savez(str(mlx_path / f"weights.npz"), **shards[0])
     else:
         for i, shard in enumerate(shards):
-            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
+            mx.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
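
A hedged usage note: after this change the converter exposes a --dtype flag (default float16, matching the previously hard-coded float16 cast), so the original bfloat16 precision can be kept end to end. For example, assuming the script file is the convert.py shown above and that the path flags match the args referenced in it (only --dtype comes from this diff):

    python convert.py --torch-path <torch-model-dir> --mlx-path <output-dir> --dtype bfloat16

The chosen dtype is threaded through llama()/tiny_llama() into torch_to_mx, and the weights are written with mx.savez rather than np.savez, since they are now MLX arrays instead of numpy arrays.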