Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
keep dtype on model conversion (#186)
This commit is contained in:
parent 85258b2be7
commit 1d09c4fecd
@@ -10,13 +10,16 @@ from pathlib import Path
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
+def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
+    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
+    a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
+    return mx.array(a.numpy(), getattr(mx, dtype))
+
+
-def llama(model_path):
+def llama(model_path, *, dtype: str):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
     SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
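For context on the hunk above: the new torch_to_mx helper exists because NumPy has no bfloat16 dtype, so bfloat16 tensors are first widened to float32 (a lossless upcast), converted to NumPy, and only then cast to the requested MLX dtype. A minimal, self-contained sketch of the same round trip (the random tensor is illustrative only):

    import mlx.core as mx
    import torch

    def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
        # bfloat16 cannot be handed to numpy directly; widen to float32 first
        a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
        return mx.array(a.numpy(), getattr(mx, dtype))

    w = torch.randn(4, 4, dtype=torch.bfloat16)
    print(torch_to_mx(w, dtype="bfloat16").dtype)  # expected: mlx.core.bfloat16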
@@ -37,14 +40,15 @@ def llama(model_path):
             axis = 1
         else:
             raise ValueError("Invalid weight name")
-        return np.concatenate(v, axis=axis)
+        return mx.concatenate(v, axis=axis)
 
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
         state = torch.load(wf, map_location=torch.device("cpu"))
         for k, v in state.items():
-            v = v.to(torch.float16).numpy()
+            v = torch_to_mx(v, dtype=dtype)
             state[k] = None  # free memory
             if shard_key(k) in SHARD_WEIGHTS:
                 weights[k].append(v)
             else:
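On the unshard step above: multi-file Llama checkpoints split each large matrix across the consolidated.*.pth shards, and the pieces are rejoined along axis 0 for the "first"-sharded weights (wv, wq, wk, w1, w3, output) and along axis 1 for the "second"-sharded ones (tok_embeddings, wo, w2). Because the pieces are now mx.array values rather than NumPy arrays, the join switches to mx.concatenate. A tiny sketch with made-up shapes:

    import mlx.core as mx

    # two pieces of one weight that was split across two checkpoint files
    pieces = [mx.zeros((2, 8)), mx.zeros((2, 8))]
    print(mx.concatenate(pieces, axis=0).shape)  # concatenated shape (4, 8): "first"-axis sharding
    print(mx.concatenate(pieces, axis=1).shape)  # concatenated shape (2, 16): "second"-axis sharding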
@@ -57,7 +61,7 @@ def llama(model_path):
     return weights, params
 
 
-def tiny_llama(model_path):
+def tiny_llama(model_path, *, dtype: str):
     try:
         import transformers
     except ImportError:
@@ -113,7 +117,7 @@ def tiny_llama(model_path):
     params["vocab_size"] = config.vocab_size
     params["norm_eps"] = config.rms_norm_eps
     params["rope_traditional"] = False
-    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+    weights = {k: torch_to_mx(v, dtype=dtype) for k, v in model.items()}
 
     return weights, params
 
@@ -197,6 +201,13 @@ if __name__ == "__main__":
         type=int,
         default=4,
     )
+    parser.add_argument(
+        "--dtype",
+        help="dtype for loading the torch model and input for quantization or saving the converted model. "
+        "The original weights are stored in bfloat16.",
+        type=str,
+        default="float16",
+    )
 
     args = parser.parse_args()
 
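The new --dtype flag lets a user keep the checkpoint's native precision (e.g. --dtype bfloat16) instead of always converting to float16; the default stays "float16", so existing behaviour is unchanged. A stripped-down argparse sketch of just this flag, showing how the string later feeds the getattr(torch, ...) and getattr(mx, ...) lookups (the surrounding CLI is assumed):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args(["--dtype", "bfloat16"])
    print(args.dtype)  # "bfloat16", later resolved via getattr(torch, ...) / getattr(mx, ...)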
@@ -205,7 +216,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)
 
     print("[INFO] Loading")
-    weights, params = globals()[args.model_name](torch_path)
+    weights, params = globals()[args.model_name](torch_path, dtype=args.dtype)
     params["model_type"] = "llama"
     if args.quantize:
         print("[INFO] Quantizing")
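The loading call above dispatches through globals(), so args.model_name selects either llama or tiny_llama by name; that is why both converters had to grow the same keyword-only dtype parameter. A stripped-down sketch of the pattern (the bodies are placeholders, not the real converters):

    def llama(model_path, *, dtype: str):
        return f"would convert {model_path} as {dtype}"

    def tiny_llama(model_path, *, dtype: str):
        return f"would convert {model_path} as {dtype}"

    model_name = "tiny_llama"  # stands in for args.model_name
    print(globals()[model_name]("some/model/path", dtype="float16"))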
@@ -218,9 +229,9 @@ if __name__ == "__main__":
         )
     shards = make_shards(weights)
     if len(shards) == 1:
-        np.savez(str(mlx_path / f"weights.npz"), **shards[0])
+        mx.savez(str(mlx_path / f"weights.npz"), **shards[0])
     else:
         for i, shard in enumerate(shards):
-            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
+            mx.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)