keep dtype on model conversion (#186)

This commit is contained in:
Daniel Strobusch 2024-01-02 20:20:29 +01:00 committed by GitHub
parent 85258b2be7
commit 1d09c4fecd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,13 +10,16 @@ from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
# bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
return mx.array(a.numpy(), getattr(mx, dtype))
def llama(model_path):
def llama(model_path, *, dtype: str):
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
@ -37,14 +40,15 @@ def llama(model_path):
axis = 1
else:
raise ValueError("Invalid weight name")
return np.concatenate(v, axis=axis)
return mx.concatenate(v, axis=axis)
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
weights = collections.defaultdict(list)
for wf in torch_files:
state = torch.load(wf, map_location=torch.device("cpu"))
for k, v in state.items():
v = v.to(torch.float16).numpy()
v = torch_to_mx(v, dtype=dtype)
state[k] = None # free memory
if shard_key(k) in SHARD_WEIGHTS:
weights[k].append(v)
else:
@ -57,7 +61,7 @@ def llama(model_path):
return weights, params
def tiny_llama(model_path):
def tiny_llama(model_path, *, dtype: str):
try:
import transformers
except ImportError:
@ -113,7 +117,7 @@ def tiny_llama(model_path):
params["vocab_size"] = config.vocab_size
params["norm_eps"] = config.rms_norm_eps
params["rope_traditional"] = False
weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
weights = {k: torch_to_mx(v, dtype=dtype) for k, v in model.items()}
return weights, params
@ -197,6 +201,13 @@ if __name__ == "__main__":
type=int,
default=4,
)
parser.add_argument(
"--dtype",
help="dtype for loading the torch model and input for quantization or saving the converted model. "
"The original weights are stored in bfloat16.",
type=str,
default="float16",
)
args = parser.parse_args()
@ -205,7 +216,7 @@ if __name__ == "__main__":
mlx_path.mkdir(parents=True, exist_ok=True)
print("[INFO] Loading")
weights, params = globals()[args.model_name](torch_path)
weights, params = globals()[args.model_name](torch_path, dtype=args.dtype)
params["model_type"] = "llama"
if args.quantize:
print("[INFO] Quantizing")
@ -218,9 +229,9 @@ if __name__ == "__main__":
)
shards = make_shards(weights)
if len(shards) == 1:
np.savez(str(mlx_path / f"weights.npz"), **shards[0])
mx.savez(str(mlx_path / f"weights.npz"), **shards[0])
else:
for i, shard in enumerate(shards):
np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
mx.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
with open(mlx_path / "config.json", "w") as fid:
json.dump(params, fid, indent=4)