Mirror of https://github.com/ml-explore/mlx-examples.git

keep dtype on model conversion (#186)

commit 1d09c4fecd
parent 85258b2be7
@@ -10,13 +10,16 @@ from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
 import torch
 from llama import Llama, ModelArgs, sanitize_config
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
+    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
+    a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
+    return mx.array(a.numpy(), getattr(mx, dtype))
 
-def llama(model_path):
+def llama(model_path, *, dtype: str):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
     SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
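
Note: numpy has no bfloat16 dtype, so calling .numpy() on a bfloat16 torch
tensor raises a TypeError; that is why torch_to_mx above routes bfloat16
through float32 (a lossless upcast, since bfloat16 is a truncated float32)
before the final cast on the MLX side. A standalone sketch of that round trip
(the tensor and shape here are illustrative, not from the repo):

    import mlx.core as mx
    import torch

    a = torch.randn(4, 4, dtype=torch.bfloat16)
    # a.numpy() would raise a TypeError: numpy cannot represent bfloat16.
    # Upcast to float32 first, then cast back to bfloat16 on the MLX side.
    b = mx.array(a.to(torch.float32).numpy(), mx.bfloat16)
    assert b.dtype == mx.bfloat16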
@@ -37,14 +40,15 @@ def llama(model_path):
             axis = 1
         else:
             raise ValueError("Invalid weight name")
-        return np.concatenate(v, axis=axis)
+        return mx.concatenate(v, axis=axis)
 
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
         state = torch.load(wf, map_location=torch.device("cpu"))
         for k, v in state.items():
-            v = v.to(torch.float16).numpy()
+            v = torch_to_mx(v, dtype=dtype)
+            state[k] = None  # free memory
             if shard_key(k) in SHARD_WEIGHTS:
                 weights[k].append(v)
             else:
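
Note: converting inside the loop and immediately setting state[k] = None drops
the torch reference for each tensor as soon as its MLX copy exists, so peak
memory stays near one duplicated tensor rather than a full second copy of the
checkpoint. A self-contained sketch of the pattern (toy state dict, not the
real checkpoint):

    import mlx.core as mx
    import torch

    state = {"w1": torch.ones(3, dtype=torch.float16),
             "w2": torch.ones(3, dtype=torch.float16)}
    for k, v in state.items():
        v = mx.array(v.numpy())  # MLX copy of the weight
        state[k] = None          # free the torch copy before the next tensor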
@@ -57,7 +61,7 @@ def llama(model_path):
     return weights, params
 
 
-def tiny_llama(model_path):
+def tiny_llama(model_path, *, dtype: str):
     try:
         import transformers
     except ImportError:
@@ -113,7 +117,7 @@ def tiny_llama(model_path):
     params["vocab_size"] = config.vocab_size
     params["norm_eps"] = config.rms_norm_eps
     params["rope_traditional"] = False
-    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+    weights = {k: torch_to_mx(v, dtype=dtype) for k, v in model.items()}
 
     return weights, params
 
@@ -197,6 +201,13 @@ if __name__ == "__main__":
         type=int,
         default=4,
     )
+    parser.add_argument(
+        "--dtype",
+        help="dtype for loading the torch model and input for quantization or saving the converted model. "
+        "The original weights are stored in bfloat16.",
+        type=str,
+        default="float16",
+    )
 
     args = parser.parse_args()
 
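
With the new argument the target dtype is picked at conversion time, for
example (other flags elided, since they are outside this diff):

    # keep the original bfloat16 weights rather than downcasting to float16
    python convert.py ... --dtype bfloat16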
@@ -205,7 +216,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)
 
     print("[INFO] Loading")
-    weights, params = globals()[args.model_name](torch_path)
+    weights, params = globals()[args.model_name](torch_path, dtype=args.dtype)
     params["model_type"] = "llama"
     if args.quantize:
         print("[INFO] Quantizing")
@@ -218,9 +229,9 @@ if __name__ == "__main__":
     )
     shards = make_shards(weights)
     if len(shards) == 1:
-        np.savez(str(mlx_path / f"weights.npz"), **shards[0])
+        mx.savez(str(mlx_path / f"weights.npz"), **shards[0])
     else:
         for i, shard in enumerate(shards):
-            np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
+            mx.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
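
Note: the switch from np.savez to mx.savez is what lets the chosen dtype
survive serialization: np.savez expects numpy arrays (and numpy cannot
represent bfloat16), while mx.savez writes mx.array values with their dtype
intact. A minimal sketch (hypothetical weight name and shape):

    import mlx.core as mx

    shard = {"tok_embeddings.weight": mx.zeros((8, 8), dtype=mx.bfloat16)}
    mx.savez("weights.npz", **shard)  # the bfloat16 dtype is kept in the file
    assert mx.load("weights.npz")["tok_embeddings.weight"].dtype == mx.bfloat16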