support for tiny llama (#129)

Awni Hannun 2023-12-18 07:47:55 -08:00 committed by GitHub
parent 08e862336a
commit 44b546d446
3 changed files with 140 additions and 45 deletions

README.md

@ -3,8 +3,8 @@
An example of generating text with Llama (1 or 2) using MLX.
Llama is a set of open source language models from Meta AI Research[^1][^2]
ranging from 7B to 70B parameters. This example also supports Llama Chat and
Code Llama.
ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
### Setup
@ -25,10 +25,19 @@ Alternatively, you can also download select converted checkpoints from the
[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
Face and skip the conversion step.
You can download the TinyLlama models directly from [Hugging
Face](https://huggingface.co/TinyLlama).
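For example, one way to fetch a TinyLlama checkpoint programmatically is with the
`huggingface_hub` package (a sketch, not part of this example; the repo id below is
illustrative and `pip install huggingface_hub` is assumed):

```
from huggingface_hub import snapshot_download

# Download an example TinyLlama checkpoint; the resulting directory can then
# be passed to convert.py via --model-path.
path = snapshot_download(
    repo_id="TinyLlama/TinyLlama-1.1B-Chat-v0.6",  # illustrative repo id
    local_dir="TinyLlama-1.1B-Chat-v0.6",
)
print(path)
```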
Convert the weights with:
```
python convert.py --model_path <path_to_torch_model>
python convert.py --model-path <path_to_torch_model>
```
For TinyLlama use
```
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
```
The conversion script will save the converted weights in the same location.
@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
Llama model:
```
python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
```
Run `python llama.py --help` for more details.
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/).
[^3]: For TinyLlama refer to the [GitHub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file).

convert.py

@ -3,45 +3,35 @@
import argparse
import collections
import glob
from pathlib import Path
import json
import numpy as np
from pathlib import Path
import torch
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
def llama(model_path):
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
def shard_key(k):
keys = k.split(".")
if len(keys) < 2:
return None
return keys[-2]
def shard_key(k):
keys = k.split(".")
if len(keys) < 2:
return None
return keys[-2]
def unshard(k, v):
wn = shard_key(k)
if wn not in SHARD_WEIGHTS:
return v
elif wn in SHARD_FIRST:
axis = 0
elif wn in SHARD_SECOND:
axis = 1
else:
raise ValueError("Invalid weight name")
return np.concatenate(v, axis=axis)
def unshard(k, v):
wn = shard_key(k)
if wn not in SHARD_WEIGHTS:
return v
elif wn in SHARD_FIRST:
axis = 0
elif wn in SHARD_SECOND:
axis = 1
else:
raise ValueError("Invalid weight name")
return np.concatenate(v, axis=axis)
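For intuition only (this snippet is not part of convert.py): the Meta checkpoints split
large matrices across the `consolidated.*.pth` shards, and `unshard` stitches them back
together, concatenating along axis 0 for the `SHARD_FIRST` weights and along axis 1 for
the `SHARD_SECOND` weights. A toy example with made-up shapes:

```
import numpy as np

# Two shards of a "wq" projection concatenate along axis 0 (SHARD_FIRST),
# two shards of a "wo" projection along axis 1 (SHARD_SECOND).
wq_shards = [np.zeros((2048, 4096)), np.ones((2048, 4096))]
wo_shards = [np.zeros((4096, 2048)), np.ones((4096, 2048))]
print(np.concatenate(wq_shards, axis=0).shape)  # (4096, 4096)
print(np.concatenate(wo_shards, axis=1).shape)  # (4096, 4096)
```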
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
"--model_path",
help="Path to the Torch model. The MLX weights will also be saved there.",
)
args = parser.parse_args()
model_path = Path(args.model_path)
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
weights = collections.defaultdict(list)
for wf in torch_files:
@ -53,7 +43,96 @@ if __name__ == "__main__":
else:
weights[k] = v
out_file = str(model_path / "weights.npz")
for k, v in weights.items():
weights[k] = unshard(k, v)
np.savez(out_file, **weights)
return weights, None
def tiny_llama(model_path):
try:
import transformers
except ImportError as e:
print("The transformers package must be installed for this model conversion:")
print("pip install transformers")
import sys
sys.exit(0)
model = transformers.AutoModelForCausalLM.from_pretrained(
str(model_path)
).state_dict()
config = transformers.AutoConfig.from_pretrained(model_path)
# things to change
# 1. there's no "model." in the weight names
model = {k.replace("model.", ""): v for k, v in model.items()}
# 2. mlp is called feed_forward
model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
# 3. up_proj, down_proj, gate_proj
model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
# 4. layernorms
model = {
k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
}
model = {
k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
}
# 5. lm head
model = {k.replace("lm_head", "output"): v for k, v in model.items()}
# 6. token emb
model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
# 7. attention
model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
params = {}
params["dim"] = config.hidden_size
params["hidden_dim"] = config.intermediate_size
params["n_heads"] = config.num_attention_heads
if hasattr(config, "num_key_value_heads"):
params["n_kv_heads"] = config.num_key_value_heads
params["n_layers"] = config.num_hidden_layers
params["vocab_size"] = config.vocab_size
params["norm_eps"] = config.rms_norm_eps
params["rope_traditional"] = False
weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
return weights, params
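To make the effect of the renames in `tiny_llama` concrete, here is a hand-worked
example (illustrative only, not part of the commit) of how one Hugging Face parameter
name ends up in the naming scheme expected by llama.py:

```
# Illustrative check of the rename chain used in tiny_llama above.
key = "model.layers.0.self_attn.q_proj.weight"
for old, new in [("model.", ""), ("self_attn", "attention"), ("q_proj", "wq")]:
    key = key.replace(old, new)
print(key)  # layers.0.attention.wq.weight
```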
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
"--model-path",
help="Path to the model. The MLX weights will also be saved there.",
)
parser.add_argument(
"--model-name",
help=(
"Name of the model to convert. Use 'llama' for models in the "
"Llama family distributed by Meta including Llama 1, Llama 2, "
"Coda Llama, and Llama chat."
),
choices=["tiny_llama", "llama"],
default="llama",
)
args = parser.parse_args()
model_path = Path(args.model_path)
weights, params = globals()[args.model_name](model_path)
np.savez(str(model_path / "weights.npz"), **weights)
if params is not None:
with open(model_path / "params.json", "w") as fid:
json.dump(params, fid, indent=4)
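A quick way to sanity-check a conversion (a sketch, not part of this commit) is to
reload the saved archive and inspect a few entries; `params.json` is only written for
the `tiny_llama` conversion:

```
import json
from pathlib import Path

import numpy as np

model_path = Path("path/to/model")  # the directory passed to convert.py
weights = np.load(model_path / "weights.npz")
for name in list(weights.files)[:5]:
    print(name, weights[name].shape, weights[name].dtype)

params_file = model_path / "params.json"
if params_file.exists():
    print(json.load(params_file.open()))
```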

llama.py

@ -24,6 +24,7 @@ class ModelArgs:
norm_eps: float
vocab_size: int
rope_theta: float
rope_traditional: bool = True
class RMSNorm(nn.Module):
@ -77,7 +78,9 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
self.rope = RoPE(
args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
)
def __call__(
self,
@ -234,7 +237,7 @@ def generate(args):
input("Press enter to start generation")
print("------")
print(args.prompt)
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
skip = 0
prompt_processing = None
@ -248,7 +251,7 @@ def generate(args):
mx.eval(token)
prompt_processing = toc("Prompt processing", start)
if len(tokens) >= args.num_tokens:
if len(tokens) >= args.max_tokens:
break
elif (len(tokens) % args.write_every) == 0:
@ -261,8 +264,7 @@ def generate(args):
mx.eval(tokens)
full_gen = toc("Full generation", start)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
print()
print(s[skip:], flush=True)
print("------")
print(prompt_processing)
print(full_gen)
@ -354,14 +356,18 @@ if __name__ == "__main__":
"model", help="Path to the model directory containing the MLX weights"
)
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument("prompt", help="The message to be processed by the model")
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--few-shot",
action="store_true",
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
)
parser.add_argument(
"--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
)
parser.add_argument(
"--write-every", type=int, default=1, help="After how many tokens to detokenize"