From 1d09c4fecd085a635f2744d63c57b37a6b1fa651 Mon Sep 17 00:00:00 2001
From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com>
Date: Tue, 2 Jan 2024 20:20:29 +0100
Subject: [PATCH] keep dtype on model conversion (#186)

---
 llms/llama/convert.py | 29 ++++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index 80f0ac13..c5e6e773 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -10,13 +10,16 @@ from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
+def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
+    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
+    a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
+    return mx.array(a.numpy(), getattr(mx, dtype))
-def llama(model_path):
+def llama(model_path, *, dtype: str):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
     SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
@@ -37,14 +40,15 @@ def llama(model_path):
             axis = 1
         else:
             raise ValueError("Invalid weight name")
-        return np.concatenate(v, axis=axis)
+        return mx.concatenate(v, axis=axis)
 
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
         state = torch.load(wf, map_location=torch.device("cpu"))
         for k, v in state.items():
-            v = v.to(torch.float16).numpy()
+            v = torch_to_mx(v, dtype=dtype)
+            state[k] = None  # free memory
             if shard_key(k) in SHARD_WEIGHTS:
                 weights[k].append(v)
             else:
@@ -57,7 +61,7 @@ def llama(model_path):
     return weights, params
 
 
-def tiny_llama(model_path):
+def tiny_llama(model_path, *, dtype: str):
     try:
         import transformers
     except ImportError:
@@ -113,7 +117,7 @@ def tiny_llama(model_path):
     params["vocab_size"] = config.vocab_size
     params["norm_eps"] = config.rms_norm_eps
     params["rope_traditional"] = False
-    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+    weights = {k: torch_to_mx(v, dtype=dtype) for k, v in model.items()}
 
     return weights, params
 
@@ -197,6 +201,13 @@ if __name__ == "__main__":
         type=int,
         default=4,
     )
+    parser.add_argument(
+        "--dtype",
+        help="dtype for loading the torch model and input for quantization or saving the converted model. "
+        "The original weights are stored in bfloat16.",
+        type=str,
+        default="float16",
+    )
 
     args = parser.parse_args()
 
@@ -205,7 +216,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)
 
     print("[INFO] Loading")
-    weights, params = globals()[args.model_name](torch_path)
+    weights, params = globals()[args.model_name](torch_path, dtype=args.dtype)
     params["model_type"] = "llama"
     if args.quantize:
         print("[INFO] Quantizing")
@@ -218,9 +229,9 @@ if __name__ == "__main__":
         )
     shards = make_shards(weights)
     if len(shards) == 1:
-        np.savez(str(mlx_path / f"weights.npz"), **shards[0])
+        mx.savez(str(mlx_path / f"weights.npz"), **shards[0])
     else:
         for i, shard in enumerate(shards):
-            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
+            mx.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
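
Usage sketch (not part of the diff above): the new --dtype flag keeps the checkpoint's
precision instead of always downcasting to float16, e.g. by passing --dtype bfloat16 to
convert.py alongside its existing path arguments. The torch_to_mx helper routes bfloat16
tensors through float32 only because numpy has no bfloat16 type; the resulting mx.array
is still stored with the requested dtype. A minimal standalone check of that behavior,
using only the function introduced in the patch:

    import mlx.core as mx
    import torch

    def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
        # bfloat16 is not numpy convertible; go through float32, then cast back in mlx
        a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
        return mx.array(a.numpy(), getattr(mx, dtype))

    x = torch.randn(4, 4, dtype=torch.bfloat16)
    y = torch_to_mx(x, dtype="bfloat16")
    print(y.dtype)  # mlx.core.bfloat16, i.e. the original precision is preserved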