support for tiny llama (#129)

Awni Hannun, 2023-12-18 07:47:55 -08:00 (committed by GitHub)
parent 08e862336a
commit 44b546d446
3 changed files with 140 additions and 45 deletions

README.md

@@ -3,8 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters. This example also supports Llama Chat and
-Code Llama.
+ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
+and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]

 ### Setup
@@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the
 [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
 Face and skip the conversion step.

+You can download the TinyLlama models directly from [Hugging
+Face](https://huggingface.co/TinyLlama).
+
 Convert the weights with:

 ```
-python convert.py --model_path <path_to_torch_model>
+python convert.py --model-path <path_to_torch_model>
+```
+
+For TinyLlama use
+
+```
+python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
 ```

 The conversion script will save the converted weights in the same location.
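
After converting, it can help to sanity-check the output before moving on. A minimal sketch, assuming the output names used by `convert.py` (`weights.npz`, plus `params.json` for the TinyLlama conversion) and a placeholder model path:

```
# Sketch: inspect the converted weights; the path is a placeholder.
import json
from pathlib import Path

import numpy as np

model_path = Path("path_to_torch_model")
weights = np.load(model_path / "weights.npz")
for name in list(weights.files)[:5]:
    print(name, weights[name].shape, weights[name].dtype)

# The TinyLlama conversion also writes the model configuration.
params_file = model_path / "params.json"
if params_file.exists():
    with open(params_file) as fid:
        print(json.load(fid))
```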
@@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
 Llama model:

 ```
-python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
+python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
 ```

 Run `python llama.py --help` for more details.

 [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
 [^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/)
+[^3]: For TinyLlama refer to the [GitHub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)

convert.py

@@ -3,23 +3,23 @@
 import argparse
 import collections
 import glob
-from pathlib import Path
+import json
 import numpy as np
+from pathlib import Path
 import torch
+def llama(model_path):
     SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
     SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
     SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
     def shard_key(k):
         keys = k.split(".")
         if len(keys) < 2:
             return None
         return keys[-2]
     def unshard(k, v):
         wn = shard_key(k)
         if wn not in SHARD_WEIGHTS:
@@ -32,16 +32,6 @@ def unshard(k, v):
             raise ValueError("Invalid weight name")
         return np.concatenate(v, axis=axis)
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument(
-        "--model_path",
-        help="Path to the Torch model. The MLX weights will also be saved there.",
-    )
-    args = parser.parse_args()
-    model_path = Path(args.model_path)
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
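
For context, `unshard` merges the per-shard tensors of a multi-file Llama checkpoint back into a single array by concatenating along the sharded axis (the axis selection itself sits in the part of the file elided above). A toy sketch of the idea, with made-up shapes and an assumed concatenation axis:

```
# Toy sketch of re-assembling a weight split across two checkpoint shards.
# Shapes and the axis are illustrative, not taken from a real model.
import numpy as np

shard_0 = np.zeros((2048, 4096), dtype=np.float16)  # e.g. a "wq"-style weight from shard 0
shard_1 = np.zeros((2048, 4096), dtype=np.float16)  # the same weight from shard 1

merged = np.concatenate([shard_0, shard_1], axis=0)
print(merged.shape)  # (4096, 4096)
```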
@@ -53,7 +43,96 @@ if __name__ == "__main__":
             else:
                 weights[k] = v
-    out_file = str(model_path / "weights.npz")
     for k, v in weights.items():
         weights[k] = unshard(k, v)
-    np.savez(out_file, **weights)
+    return weights, None
+def tiny_llama(model_path):
+    try:
+        import transformers
+    except ImportError as e:
+        print("The transformers package must be installed for this model conversion:")
+        print("pip install transformers")
+        import sys
+        sys.exit(0)
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        str(model_path)
+    ).state_dict()
+    config = transformers.AutoConfig.from_pretrained(model_path)
+    # things to change
+    # 1. there's no "model." in the weight names
+    model = {k.replace("model.", ""): v for k, v in model.items()}
+    # 2. mlp is called feed_forward
+    model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
+    # 3. up_proj, down_proj, gate_proj
+    model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
+    model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
+    model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
+    # 4. layernorms
+    model = {
+        k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
+    }
+    model = {
+        k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
+    }
+    # 5. lm head
+    model = {k.replace("lm_head", "output"): v for k, v in model.items()}
+    # 6. token emb
+    model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
+    # 7. attention
+    model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
+    model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
+    model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
+    model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
+    model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
+    params = {}
+    params["dim"] = config.hidden_size
+    params["hidden_dim"] = config.intermediate_size
+    params["n_heads"] = config.num_attention_heads
+    if hasattr(config, "num_key_value_heads"):
+        params["n_kv_heads"] = config.num_key_value_heads
+    params["n_layers"] = config.num_hidden_layers
+    params["vocab_size"] = config.vocab_size
+    params["norm_eps"] = config.rms_norm_eps
+    params["rope_traditional"] = False
+    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
+    return weights, params
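
The chain of `str.replace` calls above maps Hugging Face parameter names onto the names this repo's Llama implementation expects. A standalone sketch of the same renaming applied to a few representative keys (the keys are only illustrative):

```
# Sketch: the renaming chain from tiny_llama() applied to example HF keys.
replacements = [
    ("model.", ""), ("mlp", "feed_forward"), ("down_proj", "w2"), ("up_proj", "w3"),
    ("gate_proj", "w1"), ("input_layernorm", "attention_norm"),
    ("post_attention_layernorm", "ffn_norm"), ("lm_head", "output"),
    ("embed_tokens", "tok_embeddings"), ("self_attn", "attention"),
    ("q_proj", "wq"), ("k_proj", "wk"), ("v_proj", "wv"), ("o_proj", "wo"),
]

def rename(key):
    for old, new in replacements:
        key = key.replace(old, new)
    return key

print(rename("model.layers.0.self_attn.q_proj.weight"))  # layers.0.attention.wq.weight
print(rename("model.layers.0.mlp.down_proj.weight"))     # layers.0.feed_forward.w2.weight
print(rename("model.embed_tokens.weight"))               # tok_embeddings.weight
```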
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
+    parser.add_argument(
+        "--model-path",
+        help="Path to the model. The MLX weights will also be saved there.",
+    )
+    parser.add_argument(
+        "--model-name",
+        help=(
+            "Name of the model to convert. Use 'llama' for models in the "
+            "Llama family distributed by Meta including Llama 1, Llama 2, "
+            "Code Llama, and Llama Chat."
+        ),
+        choices=["tiny_llama", "llama"],
+        default="llama",
+    )
+    args = parser.parse_args()
+    model_path = Path(args.model_path)
+    weights, params = globals()[args.model_name](model_path)
+    np.savez(str(model_path / "weights.npz"), **weights)
+    if params is not None:
+        with open(model_path / "params.json", "w") as fid:
+            json.dump(params, fid, indent=4)
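
`globals()[args.model_name](model_path)` works because the converter functions are named exactly like the accepted `--model-name` choices. An explicit dispatch table is an equivalent way to express the same idea; a self-contained sketch with placeholder converters:

```
# Sketch: explicit dispatch table equivalent to globals()[args.model_name].
# The functions here are placeholders for the llama() and tiny_llama() above.
def llama_placeholder(model_path):
    return {}, None

def tiny_llama_placeholder(model_path):
    return {}, {"rope_traditional": False}

converters = {"llama": llama_placeholder, "tiny_llama": tiny_llama_placeholder}
weights, params = converters["tiny_llama"]("path_to_torch_model")
print(params)
```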

llama.py

@@ -24,6 +24,7 @@ class ModelArgs:
     norm_eps: float
     vocab_size: int
     rope_theta: float
+    rope_traditional: bool = True
 class RMSNorm(nn.Module):
@@ -77,7 +78,9 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
+        self.rope = RoPE(
+            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+        )
     def __call__(
         self,
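
The new `rope_traditional` field exists because TinyLlama checkpoints use the non-traditional (Hugging Face style) rotary embedding, which pairs dimensions differently; `convert.py` sets `rope_traditional` to `False` for TinyLlama. A standalone sketch, assuming `mlx.nn.RoPE` accepts the same `traditional` and `base` arguments as the call above:

```
# Sketch: both RoPE variants give same-shaped outputs but rotate different
# dimension pairings. Assumes mlx is installed and that nn.RoPE takes
# `traditional` and `base`, as the call in Attention suggests.
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 8, 16, 64))  # (batch, n_heads, seq_len, head_dim)

rope_traditional = nn.RoPE(64, traditional=True, base=10000)
rope_hf_style = nn.RoPE(64, traditional=False, base=10000)  # what TinyLlama expects

print(rope_traditional(x).shape, rope_hf_style(x).shape)
```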
@@ -234,7 +237,7 @@ def generate(args):
     input("Press enter to start generation")
     print("------")
+    print(args.prompt)
     x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
     skip = 0
     prompt_processing = None
@@ -248,7 +251,7 @@ def generate(args):
         mx.eval(token)
         prompt_processing = toc("Prompt processing", start)
-        if len(tokens) >= args.num_tokens:
+        if len(tokens) >= args.max_tokens:
             break
         elif (len(tokens) % args.write_every) == 0:
@@ -261,8 +264,7 @@ def generate(args):
     mx.eval(tokens)
     full_gen = toc("Full generation", start)
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
+    print(s[skip:], flush=True)
     print("------")
     print(prompt_processing)
     print(full_gen)
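
For context, the streaming output in `generate` works by re-decoding the full token list every `--write-every` steps and printing only the suffix past `skip`. A standalone sketch of that pattern (the decoder here is a stand-in, not the SentencePiece tokenizer):

```
# Standalone sketch of incremental detokenization with a skip pointer.
# fake_decode stands in for tokenizer.decode; the token ids are made up.
def fake_decode(tokens):
    return " ".join(str(t) for t in tokens)

tokens, skip = [], 0
write_every = 3
for token in range(10):  # pretend these are sampled token ids
    tokens.append(token)
    if (len(tokens) % write_every) == 0:
        s = fake_decode(tokens)
        print(s[skip:], end="", flush=True)
        skip = len(s)
print(fake_decode(tokens)[skip:], flush=True)  # flush whatever is left at the end
```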
@@ -354,14 +356,18 @@ if __name__ == "__main__":
         "model", help="Path to the model directory containing the MLX weights"
     )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
-    parser.add_argument("prompt", help="The message to be processed by the model")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="In the beginning the Universe was created.",
+    )
     parser.add_argument(
         "--few-shot",
         action="store_true",
         help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
     )
     parser.add_argument(
-        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
     )
     parser.add_argument(
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"