Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Commit 44b546d446: support for tiny llama (#129)
Parent: 08e862336a
llama/README.md

````diff
@@ -3,8 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters. This example also supports Llama Chat and
-Code Llama.
+ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
+and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
 
 ### Setup
 
@@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the
 [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
 Face and skip the conversion step.
 
+You can download the TinyLlama models directly from [Hugging
+Face](https://huggingface.co/TinyLlama).
+
 Convert the weights with:
 
 ```
-python convert.py --model_path <path_to_torch_model>
+python convert.py --model-path <path_to_torch_model>
+```
+
+For TinyLlama use
+
+```
+python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
 ```
 
 The conversion script will save the converted weights in the same location.
````
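As an aside (not part of the README diff): one way to fetch a TinyLlama checkpoint before running the conversion step is with `huggingface_hub`. This is a minimal sketch; the repository id and local directory below are example values, not something the README prescribes.

```python
# Sketch: download a TinyLlama checkpoint locally, then point convert.py at it.
# Assumes `pip install huggingface_hub`; the repo id is an example TinyLlama release.
from huggingface_hub import snapshot_download

model_dir = snapshot_download(
    repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    local_dir="tiny_llama",
)
print(model_dir)
# then: python convert.py --model-path tiny_llama --model-name tiny_llama
```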
````diff
@@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
 LlaMA model:
 
 ```
-python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
+python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
 ```
 
 Run `python llama.py --help` for more details.
 
 [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
 [^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/)
+[^3]: For TinyLlama refer to the [GitHub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)
````
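A quick way to sanity-check a conversion (not something the README prescribes) is to open the files the script writes. A minimal sketch, assuming the conversion already produced `weights.npz` and, for TinyLlama, `params.json` in the model directory:

```python
# Sketch: inspect what convert.py wrote into the model directory.
import json
from pathlib import Path

import numpy as np

model_path = Path("tiny_llama")  # wherever the converted weights were saved

weights = np.load(model_path / "weights.npz")
print(len(weights.files), "arrays")
print("tok_embeddings.weight" in weights.files)  # keys use the MLX Llama naming

params_file = model_path / "params.json"
if params_file.exists():  # written by the tiny_llama conversion
    print(json.loads(params_file.read_text()))
```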
llama/convert.py (119 changed lines)
The Meta Llama conversion moves from module level into a `llama()` function, with the shard helpers nested inside it:

```diff
@@ -3,24 +3,24 @@
 import argparse
 import collections
 import glob
-from pathlib import Path
+import json
 
 import numpy as np
+from pathlib import Path
 import torch
 
-SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
-SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
-SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
-
-
-def shard_key(k):
-    keys = k.split(".")
-    if len(keys) < 2:
-        return None
-    return keys[-2]
-
-
-def unshard(k, v):
-    wn = shard_key(k)
-    if wn not in SHARD_WEIGHTS:
-        return v
+
+def llama(model_path):
+    SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
+    SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
+    SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
+
+    def shard_key(k):
+        keys = k.split(".")
+        if len(keys) < 2:
+            return None
+        return keys[-2]
+
+    def unshard(k, v):
+        wn = shard_key(k)
+        if wn not in SHARD_WEIGHTS:
+            return v
```
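For intuition (an aside, not part of the diff): `unshard` re-joins weights that Meta ships split across several `consolidated.*.pth` files, concatenating the "first" group along axis 0 and the "second" group along axis 1. A toy sketch with made-up shapes:

```python
# Toy illustration of the unshard logic above; shapes are arbitrary.
import numpy as np

wq_shards = [np.zeros((4, 8)), np.zeros((4, 8))]     # e.g. attention.wq.weight (SHARD_FIRST)
tok_shards = [np.zeros((16, 3)), np.zeros((16, 3))]  # e.g. tok_embeddings.weight (SHARD_SECOND)

print(np.concatenate(wq_shards, axis=0).shape)   # (8, 8)
print(np.concatenate(tok_shards, axis=1).shape)  # (16, 6)
```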
The old command-line block in the middle of the file is removed (it is rebuilt at the end of the file, see below), and the weight-loading loop now lives inside `llama()`:

```diff
@@ -32,16 +32,6 @@ def unshard(k, v):
             raise ValueError("Invalid weight name")
         return np.concatenate(v, axis=axis)
 
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument(
-        "--model_path",
-        help="Path to the Torch model. The MLX weights will also be saved there.",
-    )
-    args = parser.parse_args()
-
-    model_path = Path(args.model_path)
     torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
     weights = collections.defaultdict(list)
     for wf in torch_files:
```
At the end of `llama()`, the weights are returned instead of being written directly:

```diff
@@ -53,7 +43,96 @@ if __name__ == "__main__":
             else:
                 weights[k] = v
 
-    out_file = str(model_path / "weights.npz")
     for k, v in weights.items():
         weights[k] = unshard(k, v)
-    np.savez(out_file, **weights)
+    return weights, None
```

The rest of this hunk is new code, shown below: a `tiny_llama()` converter and a rewritten command-line entry point.
The added `tiny_llama()` loads a Hugging Face checkpoint with `transformers`, renames the weights to the MLX Llama naming scheme, and collects the model parameters:

```python
def tiny_llama(model_path):
    try:
        import transformers
    except ImportError as e:
        print("The transformers package must be installed for this model conversion:")
        print("pip install transformers")
        import sys

        sys.exit(0)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        str(model_path)
    ).state_dict()
    config = transformers.AutoConfig.from_pretrained(model_path)

    # things to change
    # 1. there's no "model." in the weight names
    model = {k.replace("model.", ""): v for k, v in model.items()}

    # 2. mlp is called feed_forward
    model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}

    # 3. up_proj, down_proj, gate_proj
    model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
    model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
    model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}

    # 4. layernorms
    model = {
        k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
    }
    model = {
        k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
    }

    # 5. lm head
    model = {k.replace("lm_head", "output"): v for k, v in model.items()}

    # 6. token emb
    model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}

    # 7. attention
    model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
    model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
    model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
    model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
    model = {k.replace("o_proj", "wo"): v for k, v in model.items()}

    params = {}
    params["dim"] = config.hidden_size
    params["hidden_dim"] = config.intermediate_size
    params["n_heads"] = config.num_attention_heads
    if hasattr(config, "num_key_value_heads"):
        params["n_kv_heads"] = config.num_key_value_heads
    params["n_layers"] = config.num_hidden_layers
    params["vocab_size"] = config.vocab_size
    params["norm_eps"] = config.rms_norm_eps
    params["rope_traditional"] = False
    weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}

    return weights, params
```
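To see what that chain of `str.replace` calls does to a Hugging Face key, here is a small standalone sketch (an aside; the sample keys are illustrative):

```python
# Standalone illustration of the key renaming performed by tiny_llama().
def rename(k):
    for old, new in [
        ("model.", ""),
        ("mlp", "feed_forward"),
        ("down_proj", "w2"),
        ("up_proj", "w3"),
        ("gate_proj", "w1"),
        ("input_layernorm", "attention_norm"),
        ("post_attention_layernorm", "ffn_norm"),
        ("lm_head", "output"),
        ("embed_tokens", "tok_embeddings"),
        ("self_attn", "attention"),
        ("q_proj", "wq"),
        ("k_proj", "wk"),
        ("v_proj", "wv"),
        ("o_proj", "wo"),
    ]:
        k = k.replace(old, new)
    return k

print(rename("model.layers.0.self_attn.q_proj.weight"))  # layers.0.attention.wq.weight
print(rename("model.layers.0.mlp.down_proj.weight"))     # layers.0.feed_forward.w2.weight
print(rename("model.embed_tokens.weight"))               # tok_embeddings.weight
print(rename("lm_head.weight"))                          # output.weight
```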
The new entry point dispatches on `--model-name` and writes `weights.npz` plus, when the converter provides one, a `params.json`:

```python
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument(
        "--model-path",
        help="Path to the model. The MLX weights will also be saved there.",
    )
    parser.add_argument(
        "--model-name",
        help=(
            "Name of the model to convert. Use 'llama' for models in the "
            "Llama family distributed by Meta including Llama 1, Llama 2, "
            "Code Llama, and Llama chat."
        ),
        choices=["tiny_llama", "llama"],
        default="llama",
    )

    args = parser.parse_args()

    model_path = Path(args.model_path)
    weights, params = globals()[args.model_name](model_path)
    np.savez(str(model_path / "weights.npz"), **weights)
    if params is not None:
        with open(model_path / "params.json", "w") as fid:
            json.dump(params, fid, indent=4)
```
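For reference, the `params.json` written for a 1.1B TinyLlama checkpoint should look roughly like the following. The values are read from the checkpoint's `config.json` at conversion time, so treat the numbers below as illustrative rather than taken from the diff:

```python
# Roughly what tiny_llama() produces for a 1.1B checkpoint (illustrative values).
params = {
    "dim": 2048,            # config.hidden_size
    "hidden_dim": 5632,     # config.intermediate_size
    "n_heads": 32,          # config.num_attention_heads
    "n_kv_heads": 4,        # config.num_key_value_heads (grouped-query attention)
    "n_layers": 22,         # config.num_hidden_layers
    "vocab_size": 32000,    # config.vocab_size
    "norm_eps": 1e-05,      # config.rms_norm_eps
    "rope_traditional": False,
}
```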
llama/llama.py

`ModelArgs` gains a `rope_traditional` field, defaulting to `True` for the Meta checkpoints:

```diff
@@ -24,6 +24,7 @@ class ModelArgs:
     norm_eps: float
     vocab_size: int
     rope_theta: float
+    rope_traditional: bool = True
 
 
 class RMSNorm(nn.Module):
```
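Presumably the default of `True` keeps existing Meta Llama conversions working, since their `params.json` has no `rope_traditional` entry, while the TinyLlama conversion writes `false`. A minimal sketch of that pattern (the dataclass here is a stand-in, not the one in `llama.py`):

```python
# Sketch: an optional field with a default lets old params.json files load unchanged.
import json
from dataclasses import dataclass


@dataclass
class Args:  # stand-in for ModelArgs
    dim: int
    rope_traditional: bool = True


old = json.loads('{"dim": 4096}')                             # Meta Llama style params.json
new = json.loads('{"dim": 2048, "rope_traditional": false}')  # TinyLlama style params.json

print(Args(**old))  # Args(dim=4096, rope_traditional=True)
print(Args(**new))  # Args(dim=2048, rope_traditional=False)
```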
`Attention` passes the flag through to RoPE:

```diff
@@ -77,7 +78,9 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
+        self.rope = RoPE(
+            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+        )
 
     def __call__(
         self,
```
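A note on why the TinyLlama conversion sets `rope_traditional` to `False`: checkpoints exported from Hugging Face's Llama implementation use the half-split (GPT-NeoX style) rotary embedding, where dimension i is rotated with dimension i + D/2, while the "traditional" variant rotates adjacent pairs (2i, 2i+1). The sketch below only illustrates that pairing convention, not MLX's actual implementation:

```python
# Illustrative only: which feature indices are rotated together for head_dim = 8.
head_dim = 8
half = head_dim // 2

traditional_pairs = [(2 * i, 2 * i + 1) for i in range(half)]
neox_pairs = [(i, i + half) for i in range(half)]

print("traditional:", traditional_pairs)  # [(0, 1), (2, 3), (4, 5), (6, 7)]
print("neox style: ", neox_pairs)         # [(0, 4), (1, 5), (2, 6), (3, 7)]
```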
In `generate()`, the prompt is echoed before generation, `args.num_tokens` becomes `args.max_tokens`, and the final detokenized output is printed on its own line:

```diff
@@ -234,7 +237,7 @@ def generate(args):
 
     input("Press enter to start generation")
     print("------")
+    print(args.prompt)
     x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
     skip = 0
     prompt_processing = None
@@ -248,7 +251,7 @@ def generate(args):
         mx.eval(token)
         prompt_processing = toc("Prompt processing", start)
 
-        if len(tokens) >= args.num_tokens:
+        if len(tokens) >= args.max_tokens:
             break
 
         elif (len(tokens) % args.write_every) == 0:
@@ -261,8 +264,7 @@ def generate(args):
     mx.eval(tokens)
     full_gen = toc("Full generation", start)
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
+    print(s[skip:], flush=True)
     print("------")
     print(prompt_processing)
     print(full_gen)
```
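As context for the `write_every`/`skip` logic visible in these hunks (an aside, not part of the diff): the loop re-decodes the full token list every few steps and prints only the suffix that is new since the last print. A toy sketch of that pattern, with a stand-in for the tokenizer:

```python
# Toy sketch of the streaming-print pattern: decode everything, print the new tail.
tokens = []
skip = 0
write_every = 2


def decode(toks):  # stand-in for tokenizer.decode
    return "".join(toks)


for t in ["Hello", ",", " world", "!", " Bye"]:
    tokens.append(t)
    if len(tokens) % write_every == 0:
        s = decode(tokens)
        print(s[skip:], end="", flush=True)
        skip = len(s)

print(decode(tokens)[skip:], flush=True)  # flush whatever is left at the end
```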
In the argument parser, the positional prompt becomes an optional `--prompt` flag with a default, and `--num-tokens`/`-n` is renamed to `--max-tokens`/`-m`:

```diff
@@ -354,14 +356,18 @@ if __name__ == "__main__":
         "model", help="Path to the model directory containing the MLX weights"
     )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
-    parser.add_argument("prompt", help="The message to be processed by the model")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="In the beginning the Universe was created.",
+    )
     parser.add_argument(
         "--few-shot",
         action="store_true",
         help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
     )
     parser.add_argument(
-        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
     )
     parser.add_argument(
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"
```