mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Merge branch 'ml-explore:main' into prompt_lookup
This commit is contained in:
commit
cb4464bb7b
@ -1,2 +1,3 @@
|
||||
mlx
|
||||
mlx-data
|
||||
mlx-data
|
||||
numpy
|
@ -15,7 +15,6 @@ import torch
|
||||
from llama import Llama, ModelArgs, sanitize_config
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def llama(model_path):
|
||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||
@ -140,6 +139,22 @@ def quantize(weights, config, args):
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
# TODO: simplify to v.nbytes as soon as mx.array exposes it
|
||||
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument(
|
||||
@ -200,6 +215,11 @@ if __name__ == "__main__":
|
||||
str(torch_path / "tokenizer.model"),
|
||||
str(mlx_path / "tokenizer.model"),
|
||||
)
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
shards = make_shards(weights)
|
||||
if len(shards) == 1:
|
||||
np.savez(str(mlx_path / f"weights.npz"), **shards[0])
|
||||
else:
|
||||
for i, shard in enumerate(shards):
|
||||
np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(params, fid, indent=4)
|
||||
|
@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import glob
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
@ -66,7 +67,7 @@ class Attention(nn.Module):
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
@ -330,7 +331,23 @@ def sanitize_config(config, weights):
|
||||
|
||||
def load_model(model_path):
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
|
||||
unsharded_weights_path = Path(model_path / "weights.npz")
|
||||
if unsharded_weights_path.is_file():
|
||||
print("[INFO] Loading model from {}.".format(unsharded_weights_path))
|
||||
weights = mx.load(str(unsharded_weights_path))
|
||||
else:
|
||||
sharded_weights_glob = str(model_path / "weights.*.npz")
|
||||
weight_files = glob.glob(sharded_weights_glob)
|
||||
print("[INFO] Loading model from {}.".format(sharded_weights_glob))
|
||||
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No weights found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = sanitize_config(json.loads(f.read()), weights)
|
||||
quantization = config.pop("quantization", None)
|
||||
@ -373,7 +390,6 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
print("[INFO] Loading model from disk.")
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
if args.few_shot:
|
||||
few_shot_generate(args)
|
||||
|
@ -61,7 +61,7 @@ the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
|
||||
As easy as:
|
||||
|
||||
```
|
||||
python mixtral.py --model-path $MIXTRAL_MODEL/
|
||||
python mixtral.py --model-path mlx_model
|
||||
```
|
||||
|
||||
For more options including how to prompt the model, run:
|
||||
|
@ -60,7 +60,7 @@ def convert(args):
|
||||
args.model, trust_remote_code=True, torch_dtype=torch.float16
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||
weights = {replace_key(k): (v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy()) for k, v in state_dict.items()}
|
||||
config = model.config.to_dict()
|
||||
|
||||
if args.quantize:
|
||||
|
@ -41,8 +41,8 @@ def decode(model, mels):
|
||||
return decoding.decode(model, mels)
|
||||
|
||||
|
||||
def everything():
|
||||
return transcribe(audio_file)
|
||||
def everything(model_name):
|
||||
return transcribe(audio_file, model=model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -99,6 +99,6 @@ if __name__ == "__main__":
|
||||
print(f"Model forward time {model_forward_time:.3f}")
|
||||
decode_time = timer(decode, model, mels)
|
||||
print(f"Decode time {decode_time:.3f}")
|
||||
everything_time = timer(everything)
|
||||
everything_time = timer(everything, model_name)
|
||||
print(f"Everything time {everything_time:.3f}")
|
||||
print(f"\n{'-----' * 10}\n")
|
||||
|
Loading…
Reference in New Issue
Block a user