Merge branch 'ml-explore:main' into main

This commit is contained in:
Sarthak Yadav 2023-12-19 23:05:20 +01:00 committed by GitHub
commit a33c3095c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1211 additions and 60 deletions

10
ACKNOWLEDGMENTS.md Normal file
View File

@ -0,0 +1,10 @@
# Individual Contributors
If you wish to be acknowledged for your contributions, please list your name
with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` example.
MLX Examples was developed with contributions from the following individuals:
- Juarez Bochi: Added support for T5 models.

View File

@ -18,5 +18,24 @@ Some more useful examples include:
## Contributing ## Contributing
Check out the [contribution guidelines](CONTRIBUTING.md) for more information We are grateful for all of [our
on contributing to this repo. contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX Examples and wish to be acknowledged, please add your name to to the list in your
pull request.
## Citing MLX Examples
The MLX software suite was initially developed with equal contribution by Awni
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX Examples useful in your research and wish to cite it, please use the following
BibTex entry:
```
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
url = {https://github.com/ml-explore},
version = {0.0},
year = {2023},
}
```

View File

@ -3,8 +3,8 @@
An example of generating text with Llama (1 or 2) using MLX. An example of generating text with Llama (1 or 2) using MLX.
Llama is a set of open source language models from Meta AI Research[^1][^2] Llama is a set of open source language models from Meta AI Research[^1][^2]
ranging from 7B to 70B parameters. This example also supports Llama Chat and ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
Code Llama. and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
### Setup ### Setup
@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the
[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
Face and skip the conversion step. Face and skip the conversion step.
You can download the TinyLlama models directly from [Hugging
Face](https://huggingface.co/TinyLlama).
Convert the weights with: Convert the weights with:
``` ```
python convert.py --model_path <path_to_torch_model> python convert.py --model-path <path_to_torch_model>
```
For TinyLlama use
```
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
``` ```
The conversion script will save the converted weights in the same location. The conversion script will save the converted weights in the same location.
@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
LlaMA model: LlaMA model:
``` ```
python llama.py <path_to_model> <path_to_tokenizer.model> "hello" python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
``` ```
Run `python llama.py --help` for more details. Run `python llama.py --help` for more details.
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details. [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/) [^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/)
[^3]: For TinyLlama refer to the [gihub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)

View File

@ -3,24 +3,24 @@
import argparse import argparse
import collections import collections
import glob import glob
from pathlib import Path import json
import numpy as np import numpy as np
from pathlib import Path
import torch import torch
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
def llama(model_path):
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
def shard_key(k): def shard_key(k):
keys = k.split(".") keys = k.split(".")
if len(keys) < 2: if len(keys) < 2:
return None return None
return keys[-2] return keys[-2]
def unshard(k, v):
def unshard(k, v):
wn = shard_key(k) wn = shard_key(k)
if wn not in SHARD_WEIGHTS: if wn not in SHARD_WEIGHTS:
return v return v
@ -32,16 +32,6 @@ def unshard(k, v):
raise ValueError("Invalid weight name") raise ValueError("Invalid weight name")
return np.concatenate(v, axis=axis) return np.concatenate(v, axis=axis)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
"--model_path",
help="Path to the Torch model. The MLX weights will also be saved there.",
)
args = parser.parse_args()
model_path = Path(args.model_path)
torch_files = glob.glob(str(model_path / "consolidated.*.pth")) torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
weights = collections.defaultdict(list) weights = collections.defaultdict(list)
for wf in torch_files: for wf in torch_files:
@ -53,7 +43,96 @@ if __name__ == "__main__":
else: else:
weights[k] = v weights[k] = v
out_file = str(model_path / "weights.npz")
for k, v in weights.items(): for k, v in weights.items():
weights[k] = unshard(k, v) weights[k] = unshard(k, v)
np.savez(out_file, **weights) return weights, None
def tiny_llama(model_path):
try:
import transformers
except ImportError as e:
print("The transformers package must be installed for this model conversion:")
print("pip install transformers")
import sys
sys.exit(0)
model = transformers.AutoModelForCausalLM.from_pretrained(
str(model_path)
).state_dict()
config = transformers.AutoConfig.from_pretrained(model_path)
# things to change
# 1. there's no "model." in the weight names
model = {k.replace("model.", ""): v for k, v in model.items()}
# 2. mlp is called feed_forward
model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
# 3. up_proj, down_proj, gate_proj
model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
# 4. layernorms
model = {
k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
}
model = {
k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
}
# 5. lm head
model = {k.replace("lm_head", "output"): v for k, v in model.items()}
# 6. token emb
model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
# 7. attention
model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
params = {}
params["dim"] = config.hidden_size
params["hidden_dim"] = config.intermediate_size
params["n_heads"] = config.num_attention_heads
if hasattr(config, "num_key_value_heads"):
params["n_kv_heads"] = config.num_key_value_heads
params["n_layers"] = config.num_hidden_layers
params["vocab_size"] = config.vocab_size
params["norm_eps"] = config.rms_norm_eps
params["rope_traditional"] = False
weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
return weights, params
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
"--model-path",
help="Path to the model. The MLX weights will also be saved there.",
)
parser.add_argument(
"--model-name",
help=(
"Name of the model to convert. Use 'llama' for models in the "
"Llama family distributed by Meta including Llama 1, Llama 2, "
"Coda Llama, and Llama chat."
),
choices=["tiny_llama", "llama"],
default="llama",
)
args = parser.parse_args()
model_path = Path(args.model_path)
weights, params = globals()[args.model_name](model_path)
np.savez(str(model_path / "weights.npz"), **weights)
if params is not None:
with open(model_path / "params.json", "w") as fid:
json.dump(params, fid, indent=4)

View File

@ -24,6 +24,7 @@ class ModelArgs:
norm_eps: float norm_eps: float
vocab_size: int vocab_size: int
rope_theta: float rope_theta: float
rope_traditional: bool = True
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
@ -77,7 +78,9 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta) self.rope = RoPE(
args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
)
def __call__( def __call__(
self, self,
@ -234,7 +237,7 @@ def generate(args):
input("Press enter to start generation") input("Press enter to start generation")
print("------") print("------")
print(args.prompt)
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
skip = 0 skip = 0
prompt_processing = None prompt_processing = None
@ -248,7 +251,7 @@ def generate(args):
mx.eval(token) mx.eval(token)
prompt_processing = toc("Prompt processing", start) prompt_processing = toc("Prompt processing", start)
if len(tokens) >= args.num_tokens: if len(tokens) >= args.max_tokens:
break break
elif (len(tokens) % args.write_every) == 0: elif (len(tokens) % args.write_every) == 0:
@ -261,8 +264,7 @@ def generate(args):
mx.eval(tokens) mx.eval(tokens)
full_gen = toc("Full generation", start) full_gen = toc("Full generation", start)
s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True) print(s[skip:], flush=True)
print()
print("------") print("------")
print(prompt_processing) print(prompt_processing)
print(full_gen) print(full_gen)
@ -292,7 +294,7 @@ def few_shot_generate(args):
mx.eval(token) mx.eval(token)
prompt_processing = toc("Prompt processing", start) prompt_processing = toc("Prompt processing", start)
if len(tokens) >= args.num_tokens: if len(tokens) >= args.max_tokens:
break break
mx.eval(tokens) mx.eval(tokens)
@ -316,7 +318,8 @@ def few_shot_generate(args):
s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True) print(s[skip:], end="", flush=True)
prompt = open(args.prompt).read().strip() print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
prompt = open(args.few_shot).read().strip()
while True: while True:
question = input("Ask a question: ") question = input("Ask a question: ")
generate(prompt.replace("{}", question)) generate(prompt.replace("{}", question))
@ -354,14 +357,17 @@ if __name__ == "__main__":
"model", help="Path to the model directory containing the MLX weights" "model", help="Path to the model directory containing the MLX weights"
) )
parser.add_argument("tokenizer", help="The sentencepiece tokenizer") parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument("prompt", help="The message to be processed by the model") parser.add_argument(
"--prompt",
help="The message to be processed by the model. Ignored when --few-shot is provided.",
default="In the beginning the Universe was created.",
)
parser.add_argument( parser.add_argument(
"--few-shot", "--few-shot",
action="store_true",
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
) )
parser.add_argument( parser.add_argument(
"--num-tokens", "-n", type=int, default=100, help="How many tokens to generate" "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
) )
parser.add_argument( parser.add_argument(
"--write-every", type=int, default=1, help="After how many tokens to detokenize" "--write-every", type=int, default=1, help="After how many tokens to detokenize"

2
llms/qwen/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
weights.npz
config.json

41
llms/qwen/README.md Normal file
View File

@ -0,0 +1,41 @@
# Qwen
Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1]
The architecture of the Qwen models is similar to Llama except for the bias in
the attention layers.
## Setup
First download and convert the model with:
```sh
python convert.py
```
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
The conversion script will make the `weights.npz` and `config.json` files in
the working directory.
## Generate
To generate text with the default prompt:
```sh
python qwen.py
```
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
for Qwen 7B use:
```
python qwen.py --tokenizer Qwen/Qwen-7B
```
To see a list of options, run:
```sh
python qwen.py --help
```
[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen).

42
llms/qwen/convert.py Normal file
View File

@ -0,0 +1,42 @@
import argparse
from transformers import AutoModelForCausalLM
import numpy as np
import torch
import json
def replace_key(key: str) -> str:
if key.startswith("transformer."):
# remove transformer prefix
key = key.replace("transformer.", "")
return key
def convert(model_path: str = "Qwen/Qwen-1_8B"):
model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True, torch_dtype=torch.float16
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
# write config
config = model.config
config_dict = config.to_dict()
with open("config.json", "w") as f:
json.dump(config_dict, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
parser.add_argument(
"--model",
help="The huggingface model to be converted",
default="Qwen/Qwen-1_8B",
)
args = parser.parse_args()
convert(args.model)

269
llms/qwen/qwen.py Normal file
View File

@ -0,0 +1,269 @@
import argparse
from dataclasses import dataclass
import json
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
kv_channels: int = 128
max_position_embeddings: int = 8192
layer_norm_epsilon: float = 1e-6
intermediate_size: int = 11008
no_bias: bool = True
vocab_size: int = 151936
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
self.scale = hidden_size_per_attention_head**-0.5
def __call__(self, x, mask=None, cache=None):
qkv = self.c_attn(x)
q, k, v = mx.split(qkv, 3, axis=-1)
B, L, _ = q.shape
q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
k_cache, v_cache = cache
q = self.rotary_emb(q, offset=k_cache.shape[2])
k = self.rotary_emb(k, offset=k_cache.shape[2])
k = mx.concatenate([k_cache, k], axis=2)
v = mx.concatenate([v_cache, v], axis=2)
else:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(v_hat), (k, v)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.w2 = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
)
self.c_proj = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
)
def __call__(self, x):
a1 = self.w1(x)
a2 = self.w2(x)
return self.c_proj(a1 * nn.silu(a2))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(self, x, mask=None, cache=None):
residual = x
x = self.ln_1(x)
x, cache = self.attn(x, mask=mask, cache=cache)
residual = x + residual
x = self.ln_2(residual)
x = self.mlp(x)
x = x + residual
return x, cache
class Qwen(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embed_dim = args.hidden_size
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = None
T = x.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x[:, T - 1 : T, :])
return self.lm_head(x), cache
def generate(prompt: mx.array, model: Qwen, temp: 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt)
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache=cache)
y = sample(logits.squeeze(1))
yield y
def load_model(
tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
):
model_args = ModelArgs()
with open(config_path, "r") as f:
config = json.load(f)
model_args.vocab_size = config["vocab_size"]
model_args.hidden_size = config["hidden_size"]
model_args.num_attention_heads = config["num_attention_heads"]
model_args.num_hidden_layers = config["num_hidden_layers"]
model_args.kv_channels = config["kv_channels"]
model_args.max_position_embeddings = config["max_position_embeddings"]
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
model_args.intermediate_size = config["intermediate_size"]
model_args.no_bias = config["no_bias"]
model = Qwen(model_args)
weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
)
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Qwen inference script")
parser.add_argument(
"--tokenizer",
help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
default="Qwen/Qwen-1_8B",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
# The example from the official huggingface repo of Qwen
default="蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是",
)
parser.add_argument(
"--max_tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model(args.tokenizer)
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
prompt = mx.array(prompt)
print(args.prompt, end="", flush=True)
tokens = []
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next(
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
None,
)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)

View File

@ -0,0 +1,7 @@
einops
mlx
numpy
transformers>=4.35
transformers_stream_generator>=0.0.4
torch
tiktoken

View File

@ -32,21 +32,30 @@ if __name__ == "__main__":
os.makedirs(args.mlx_model) os.makedirs(args.mlx_model)
mlx_path = Path(args.mlx_model) mlx_path = Path(args.mlx_model)
# Copy the tokenizer
tokenizer_path = torch_path / "tokenizer.model"
if not tokenizer_path.exists():
print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
exit(0)
shutil.copyfile(
str(tokenizer_path),
str(mlx_path / "tokenizer.model"),
)
# Copy the model weights
state = torch.load(str(torch_path / "consolidated.00.pth")) state = torch.load(str(torch_path / "consolidated.00.pth"))
np.savez( np.savez(
str(mlx_path / "weights.npz"), str(mlx_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()} **{k: v.to(torch.float16).numpy() for k, v in state.items()},
)
# Copy the tokenizer
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
) )
# Copy the params # Copy the params
with open(torch_path / "params.json", "r") as f: with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read()) config = json.loads(f.read())
unused = ["multiple_of"]
for k in unused:
if k in config:
config.pop(k)
n_heads = config["n_heads"] n_heads = config["n_heads"]
if "sliding_window" in config: if "sliding_window" in config:
config.pop("sliding_window") config.pop("sliding_window")
@ -55,6 +64,6 @@ if __name__ == "__main__":
if "head_dim" not in config: if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config: if "hidden_dim" not in config:
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
with open(mlx_path / "params.json", "w") as outfile: with open(mlx_path / "params.json", "w") as outfile:
json.dump(config, outfile) json.dump(config, outfile, indent=4)

View File

@ -332,9 +332,9 @@ def load_model(folder: str, dtype=mx.float16):
tokenizer = Tokenizer(str(model_path / "tokenizer.model")) tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f: with open(model_path / "params.json", "r") as f:
config = json.loads(f.read()) config = json.loads(f.read())
model_args = ModelArgs(**config)
if config.get("vocab_size", -1) < 0: if config.get("vocab_size", -1) < 0:
config["vocab_size"] = tokenizer.vocab_size config["vocab_size"] = tokenizer.vocab_size
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz")) weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items())) weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights) weights = tree_map(lambda p: p.astype(dtype), weights)

1
t5/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.npz

53
t5/README.md Normal file
View File

@ -0,0 +1,53 @@
# T5
The T5 models are encoder-decoder models pre-trained on a mixture of
unsupervised and supervised tasks.[^1] These models work well on a variety of
tasks by prepending task-specific prefixes to the input, e.g.:
`translate English to German: …`, `summarize: ….`, etc.
This example also supports the FLAN-T5 models variants.[^2]
## Setup
Download and convert the model:
```sh
python convert.py --model <model>
```
This will make the `<model>.npz` file which MLX can read.
The `<model>` can be any of the following:
| Model Name | Model Size |
| ---------- | ----------
| t5-small | 60 million |
| t5-base | 220 million |
| t5-large | 770 million |
| t5-3b | 3 billion |
| t5-11b | 11 billion |
The FLAN variants can be specified with `google/flan-t5-small`,
`google/flan-t5-base`, etc. See the [Hugging Face
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
complete list of models.
## Generate
Generate text with:
```sh
python t5.py --model t5-small --prompt "translate English to German: A tasty apple"
```
This should give the output: `Ein leckerer Apfel`
To see a list of options run:
```sh
python t5.py --help
```
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).

77
t5/convert.py Normal file
View File

@ -0,0 +1,77 @@
from transformers import T5ForConditionalGeneration
import numpy as np
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
def replace_key(key: str) -> str:
for old, new in SHARED_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in ENCODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in DECODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
return key
def convert(model_name, dtype):
dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {
replace_key(k): v.numpy().astype(dtype)
for k, v in model.state_dict().items()
}
file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz")
np.savez(f"{file_name}.npz", **weights)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
parser.add_argument(
"--model",
type=str,
help="Name of the T5 model.",
default="t5-small",
)
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "float32"],
default="float32",
)
args = parser.parse_args()
convert(args.model, args.dtype)

54
t5/hf_t5.py Normal file
View File

@ -0,0 +1,54 @@
from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
import argparse
def embed(t5_model: str):
batch = [
"translate English to German: That is good.",
"This is an example of T5 working on MLX.",
]
tokenizer = AutoTokenizer.from_pretrained(t5_model)
torch_model = T5EncoderModel.from_pretrained(t5_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
torch_output = torch_forward.last_hidden_state.detach().numpy()
print("\n TF BERT:")
for input_str, embedding in list(zip(batch, torch_output)):
print("Input:", input_str)
print(embedding)
print()
def generate(t5_model: str):
prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
tokenizer = AutoTokenizer.from_pretrained(t5_model)
torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the T5 model using Hugging Face Transformers."
)
parser.add_argument(
"--encode-only",
action="store_true",
help="Only run the encoder and print the embeddings.",
default=False,
)
parser.add_argument(
"--model",
default="t5-small",
help="The huggingface name of the T5 model to save.",
)
args = parser.parse_args()
if args.encode_only:
embed(args.model)
else:
generate(args.model)

3
t5/requirements.txt Normal file
View File

@ -0,0 +1,3 @@
mlx
numpy
transformers

469
t5/t5.py Normal file
View File

@ -0,0 +1,469 @@
import argparse
from typing import Optional, Tuple, List
from time import perf_counter_ns
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten, tree_map
from transformers import T5Config, T5Tokenizer
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from HF Tensorflow:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position)
else:
relative_position = -mx.minimum(
relative_position, mx.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
).astype(mx.int16)
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += mx.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
)
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> [mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=3)
values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim]
queries = queries
scores = queries @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
t = x.dtype
output = self._norm(x).astype(t)
return self.weight * output
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(
self,
x: mx.array,
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
y, _ = self.cross_attention(y, memory, memory, memory_mask)
x = x + y
y = self.ln3(x)
y = self.dense(y)
x = x + y
return x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerDecoderLayer(config) for i in range(config.num_layers)
]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
def __call__(self, x, memory, mask, memory_mask, cache=None):
if cache is not None:
offset = cache[0][0].shape[3]
else:
offset = 0
cache = [None] * len(self.layers)
T = offset + x.shape[1]
pos_bias = self.relative_attention_bias(T, T, offset=offset)
if mask is not None:
mask += pos_bias
else:
mask = pos_bias
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
x = self.ln(x)
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
return self.linear(inputs)
class T5(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.tie_word_embeddings = config.tie_word_embeddings
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
def encode(self, inputs: mx.array):
return self.encoder(self.wte(inputs))
def decode(
self,
inputs: mx.array,
memory: mx.array,
cache=None,
):
inputs = self.wte(inputs)
T = inputs.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(inputs.dtype)
else:
mask = None
y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
)
if not self.tie_word_embeddings:
y *= self.model_dim**-0.5
y = self.lm_head(y)
else:
y = y @ self.wte.weight.T
return y, cache
def __call__(
self,
inputs: mx.array,
decoder_inputs: mx.array,
):
return self.decode(decoder_inputs, self.encode(inputs))[0]
class Tokenizer:
def __init__(self, model_name: str, config: T5Config):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=getattr(config, 'n_positions', 512)
)
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
prompt = tokenizer.encode(prompt)
decoder_inputs = mx.array([tokenizer.decoder_start_id])
memory = model.encode(prompt)
cache = None
y = decoder_inputs
while True:
logits, cache = model.decode(y[None], memory, cache=cache)
y = sample(logits[:, -1, :])
yield y.squeeze()
def load_model(model_name: str, dtype: str = "float16"):
config = T5Config.from_pretrained(args.model)
dtype = getattr(mx, dtype)
model = T5(config)
file_name = model_name.replace("/", "-")
weights = mx.load(f"{file_name}.npz")
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model.update(weights)
mx.eval(model.parameters())
return model, Tokenizer(args.model, config)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="T5 Inference script")
parser.add_argument(
"--model",
type=str,
help="Name of the T5 model.",
default="t5-small",
)
parser.add_argument(
"--prompt",
help="",
default="translate English to German: That is good.",
)
parser.add_argument(
"--encode-only",
action="store_true",
default=False,
help="Whether to decode or not. If true, will output last layer of encoder.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="bfloat16",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model(args.model, args.dtype)
if args.encode_only:
print("[INFO] Encoding with T5...", flush=True)
print(args.prompt, flush=True)
encoder_output = model.encode(tokenizer.encode(args.prompt))
print(encoder_output, flush=True)
exit(0)
print("[INFO] Generating with T5...", flush=True)
print("Input: ", args.prompt, flush=True)
start = perf_counter_ns()
for token, n_tokens in zip(
generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
):
if token.item() == tokenizer.eos_id:
break
print(
tokenizer.decode([token.item()], with_sep=n_tokens > 0),
end="",
flush=True,
)
n_tokens += 1
end = perf_counter_ns()
elapsed = (end - start) / 1.0e9
print()
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")