mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

support for tiny llama (#129)

This commit is contained in:
parent 08e862336a
commit 44b546d446
@@ -3,8 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.

 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters. This example also supports Llama Chat and
-Code Llama.
+ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
+and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]

 ### Setup

@@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the
 [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
 Face and skip the conversion step.

+You can download the TinyLlama models directly from [Hugging
+Face](https://huggingface.co/TinyLlama).
+
 Convert the weights with:

 ```
-python convert.py --model_path <path_to_torch_model>
+python convert.py --model-path <path_to_torch_model>
 ```

+For TinyLlama use
+
+```
+python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
+```
+
 The conversion script will save the converted weights in the same location.
@@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
 Llama model:

 ```
-python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
+python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
 ```

 Run `python llama.py --help` for more details.

 [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
 [^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/)
+[^3]: For TinyLlama refer to the [GitHub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)
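A quick way to sanity-check a conversion is to load the saved `weights.npz` back with NumPy and look at a few array names and shapes. This is a minimal sketch, not part of the example; the file name comes from `convert.py` below, and the directory path is a placeholder.

```
# Sketch: inspect a converted checkpoint (the path is hypothetical).
import numpy as np

weights = np.load("<path_to_torch_model>/weights.npz")
for name in list(weights.keys())[:5]:
    print(name, weights[name].shape, weights[name].dtype)
```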
147 llama/convert.py
@@ -3,45 +3,35 @@
 import argparse
 import collections
 import glob
-from pathlib import Path
-
+import json
 import numpy as np
+from pathlib import Path
 import torch

-SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
-SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
-SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)

+def llama(model_path):
+    SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+    SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+    SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)

-def shard_key(k):
-    keys = k.split(".")
-    if len(keys) < 2:
-        return None
-    return keys[-2]
+    def shard_key(k):
+        keys = k.split(".")
+        if len(keys) < 2:
+            return None
+        return keys[-2]

-def unshard(k, v):
-    wn = shard_key(k)
-    if wn not in SHARD_WEIGHTS:
-        return v
-    elif wn in SHARD_FIRST:
-        axis = 0
-    elif wn in SHARD_SECOND:
-        axis = 1
-    else:
-        raise ValueError("Invalid weight name")
-    return np.concatenate(v, axis=axis)
+    def unshard(k, v):
+        wn = shard_key(k)
+        if wn not in SHARD_WEIGHTS:
+            return v
+        elif wn in SHARD_FIRST:
+            axis = 0
+        elif wn in SHARD_SECOND:
+            axis = 1
+        else:
+            raise ValueError("Invalid weight name")
+        return np.concatenate(v, axis=axis)

-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument(
-        "--model_path",
-        help="Path to the Torch model. The MLX weights will also be saved there.",
-    )
-    args = parser.parse_args()
-
-    model_path = Path(args.model_path)
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
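In `unshard` above, weight names listed in `SHARD_FIRST` are concatenated along axis 0 across the `consolidated.*.pth` shards, while names in `SHARD_SECOND` are concatenated along axis 1. A minimal standalone sketch with made-up shapes:

```
# Toy illustration of the unshard rule above (shapes are hypothetical).
import numpy as np

# "wq" is in SHARD_FIRST, so its shards join along axis 0.
wq_shards = [np.zeros((8, 32)), np.zeros((8, 32))]
print(np.concatenate(wq_shards, axis=0).shape)  # (16, 32)

# "wo" is in SHARD_SECOND, so its shards join along axis 1.
wo_shards = [np.zeros((32, 8)), np.zeros((32, 8))]
print(np.concatenate(wo_shards, axis=1).shape)  # (32, 16)
```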
@@ -53,7 +43,96 @@ if __name__ == "__main__":
             else:
                 weights[k] = v

-    out_file = str(model_path / "weights.npz")
     for k, v in weights.items():
         weights[k] = unshard(k, v)
-    np.savez(out_file, **weights)
+    return weights, None
+
+
+def tiny_llama(model_path):
+    try:
+        import transformers
+    except ImportError as e:
+        print("The transformers package must be installed for this model conversion:")
+        print("pip install transformers")
+        import sys
+
+        sys.exit(0)
+
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        str(model_path)
+    ).state_dict()
+    config = transformers.AutoConfig.from_pretrained(model_path)
+
+    # things to change
+    # 1. there's no "model." in the weight names
+    model = {k.replace("model.", ""): v for k, v in model.items()}
+
+    # 2. mlp is called feed_forward
+    model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
+
+    # 3. up_proj, down_proj, gate_proj
+    model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
+    model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
+    model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
+
+    # 4. layernorms
+    model = {
+        k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
+    }
+    model = {
+        k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
+    }
+
+    # 5. lm head
+    model = {k.replace("lm_head", "output"): v for k, v in model.items()}
+
+    # 6. token emb
+    model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
+
+    # 7. attention
+    model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
+    model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
+    model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
+    model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
+    model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
+
+    params = {}
+    params["dim"] = config.hidden_size
+    params["hidden_dim"] = config.intermediate_size
+    params["n_heads"] = config.num_attention_heads
+    if hasattr(config, "num_key_value_heads"):
+        params["n_kv_heads"] = config.num_key_value_heads
+    params["n_layers"] = config.num_hidden_layers
+    params["vocab_size"] = config.vocab_size
+    params["norm_eps"] = config.rms_norm_eps
+    params["rope_traditional"] = False
+    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+
+    return weights, params
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
+    parser.add_argument(
+        "--model-path",
+        help="Path to the model. The MLX weights will also be saved there.",
+    )
+    parser.add_argument(
+        "--model-name",
+        help=(
+            "Name of the model to convert. Use 'llama' for models in the "
+            "Llama family distributed by Meta including Llama 1, Llama 2, "
+            "Code Llama, and Llama Chat."
+        ),
+        choices=["tiny_llama", "llama"],
+        default="llama",
+    )
+
+    args = parser.parse_args()
+
+    model_path = Path(args.model_path)
+    weights, params = globals()[args.model_name](model_path)
+    np.savez(str(model_path / "weights.npz"), **weights)
+    if params is not None:
+        with open(model_path / "params.json", "w") as fid:
+            json.dump(params, fid, indent=4)
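Most of the new `tiny_llama` function is a chain of key renames from the Hugging Face checkpoint layout to the parameter names `llama.py` expects. A small sketch of the same idea, using the mapping pairs from the function above on a couple of hypothetical checkpoint keys:

```
# Sketch: the rename chain from tiny_llama(), applied to hypothetical keys.
RENAMES = [
    ("model.", ""),
    ("mlp", "feed_forward"),
    ("down_proj", "w2"),
    ("up_proj", "w3"),
    ("gate_proj", "w1"),
    ("input_layernorm", "attention_norm"),
    ("post_attention_layernorm", "ffn_norm"),
    ("lm_head", "output"),
    ("embed_tokens", "tok_embeddings"),
    ("self_attn", "attention"),
    ("q_proj", "wq"),
    ("k_proj", "wk"),
    ("v_proj", "wv"),
    ("o_proj", "wo"),
]


def rename(key):
    for old, new in RENAMES:
        key = key.replace(old, new)
    return key


print(rename("model.layers.0.self_attn.q_proj.weight"))  # layers.0.attention.wq.weight
print(rename("model.layers.0.mlp.down_proj.weight"))     # layers.0.feed_forward.w2.weight
```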
llama/llama.py
@@ -24,6 +24,7 @@ class ModelArgs:
     norm_eps: float
     vocab_size: int
     rope_theta: float
+    rope_traditional: bool = True


 class RMSNorm(nn.Module):
@@ -77,7 +78,9 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
+        self.rope = RoPE(
+            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+        )

     def __call__(
         self,
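The new `rope_traditional` field defaults to `True`, so existing Meta Llama conversions keep the previous RoPE behaviour, while the TinyLlama converter writes `"rope_traditional": false` into `params.json`. A minimal sketch of reading such a flag with a backward-compatible default; the loading code is illustrative and the path is a placeholder, not the example's actual loader:

```
# Illustrative only: read params.json and fall back to traditional RoPE
# when the key is absent (e.g. an older Meta Llama conversion).
import json

with open("<path_to_model>/params.json") as fid:
    params = json.load(fid)

rope_traditional = params.get("rope_traditional", True)
print("traditional RoPE:", rope_traditional)
```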
@@ -234,7 +237,7 @@ def generate(args):

     input("Press enter to start generation")
     print("------")
-
+    print(args.prompt)
     x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
     skip = 0
     prompt_processing = None
@@ -248,7 +251,7 @@
             mx.eval(token)
             prompt_processing = toc("Prompt processing", start)

-        if len(tokens) >= args.num_tokens:
+        if len(tokens) >= args.max_tokens:
             break

         elif (len(tokens) % args.write_every) == 0:
@@ -261,8 +264,7 @@
     mx.eval(tokens)
     full_gen = toc("Full generation", start)
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
+    print(s[skip:], flush=True)
     print("------")
     print(prompt_processing)
     print(full_gen)
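The generation loop above decodes the running token list every `--write-every` steps and prints only the suffix it has not shown yet (tracked by `skip`); the final `print(s[skip:], flush=True)` emits whatever remains once `--max-tokens` is reached. A self-contained sketch of that streaming pattern, with a plain string join standing in for the sentencepiece decode:

```
# Sketch of the streaming print pattern in generate(); "decode" is faked
# by joining characters, and the token values are hypothetical.
tokens = list("In the beginning the Universe was created.")
write_every, max_tokens = 4, len(tokens)

shown, skip = [], 0
for t in tokens:
    shown.append(t)
    if len(shown) >= max_tokens:
        break
    if len(shown) % write_every == 0:
        s = "".join(shown)                  # stands in for tokenizer.decode(...)
        print(s[skip:], end="", flush=True)
        skip = len(s)

s = "".join(shown)
print(s[skip:], flush=True)                 # print whatever was not shown yet
```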
@@ -354,14 +356,18 @@ if __name__ == "__main__":
         "model", help="Path to the model directory containing the MLX weights"
     )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
-    parser.add_argument("prompt", help="The message to be processed by the model")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="In the beginning the Universe was created.",
+    )
     parser.add_argument(
         "--few-shot",
         action="store_true",
         help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
     )
     parser.add_argument(
-        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
     )
     parser.add_argument(
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"