mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
@@ -14,11 +14,13 @@ import torch
|
||||
from llama import Llama, ModelArgs, sanitize_config
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
|
||||
# bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
|
||||
a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
|
||||
a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
|
||||
return mx.array(a.numpy(), getattr(mx, dtype))
|
||||
|
||||
|
||||
def llama(model_path, *, dtype: str):
|
||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||
@@ -48,7 +50,7 @@ def llama(model_path, *, dtype: str):
|
||||
state = torch.load(wf, map_location=torch.device("cpu"))
|
||||
for k, v in state.items():
|
||||
v = torch_to_mx(v, dtype=dtype)
|
||||
state[k] = None # free memory
|
||||
state[k] = None # free memory
|
||||
if shard_key(k) in SHARD_WEIGHTS:
|
||||
weights[k].append(v)
|
||||
else:
|
||||
@@ -204,7 +206,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="dtype for loading the torch model and input for quantization or saving the converted model. "
|
||||
"The original weights are stored in bfloat16.",
|
||||
"The original weights are stored in bfloat16.",
|
||||
type=str,
|
||||
default="float16",
|
||||
)
|
||||
|
Reference in New Issue
Block a user