mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Merge branch 'ml-explore:main' into main
This commit is contained in:
commit
a33c3095c4
10
ACKNOWLEDGMENTS.md
Normal file
10
ACKNOWLEDGMENTS.md
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Individual Contributors
|
||||||
|
|
||||||
|
If you wish to be acknowledged for your contributions, please list your name
|
||||||
|
with a short description of your contribution(s) below. For example:
|
||||||
|
|
||||||
|
- Jane Smith: Added the `foo` example.
|
||||||
|
|
||||||
|
MLX Examples was developed with contributions from the following individuals:
|
||||||
|
|
||||||
|
- Juarez Bochi: Added support for T5 models.
|
23
README.md
23
README.md
@ -18,5 +18,24 @@ Some more useful examples include:
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
We are grateful for all of [our
|
||||||
on contributing to this repo.
|
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||||
|
to MLX Examples and wish to be acknowledged, please add your name to to the list in your
|
||||||
|
pull request.
|
||||||
|
|
||||||
|
## Citing MLX Examples
|
||||||
|
|
||||||
|
The MLX software suite was initially developed with equal contribution by Awni
|
||||||
|
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
||||||
|
MLX Examples useful in your research and wish to cite it, please use the following
|
||||||
|
BibTex entry:
|
||||||
|
|
||||||
|
```
|
||||||
|
@software{mlx2023,
|
||||||
|
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||||
|
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||||
|
url = {https://github.com/ml-explore},
|
||||||
|
version = {0.0},
|
||||||
|
year = {2023},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
An example of generating text with Llama (1 or 2) using MLX.
|
An example of generating text with Llama (1 or 2) using MLX.
|
||||||
|
|
||||||
Llama is a set of open source language models from Meta AI Research[^1][^2]
|
Llama is a set of open source language models from Meta AI Research[^1][^2]
|
||||||
ranging from 7B to 70B parameters. This example also supports Llama Chat and
|
ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
|
||||||
Code Llama.
|
and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the
|
|||||||
[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
|
[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
|
||||||
Face and skip the conversion step.
|
Face and skip the conversion step.
|
||||||
|
|
||||||
|
You can download the TinyLlama models directly from [Hugging
|
||||||
|
Face](https://huggingface.co/TinyLlama).
|
||||||
|
|
||||||
Convert the weights with:
|
Convert the weights with:
|
||||||
|
|
||||||
```
|
```
|
||||||
python convert.py --model_path <path_to_torch_model>
|
python convert.py --model-path <path_to_torch_model>
|
||||||
|
```
|
||||||
|
|
||||||
|
For TinyLlama use
|
||||||
|
|
||||||
|
```
|
||||||
|
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
|
||||||
```
|
```
|
||||||
|
|
||||||
The conversion script will save the converted weights in the same location.
|
The conversion script will save the converted weights in the same location.
|
||||||
@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the
|
|||||||
LlaMA model:
|
LlaMA model:
|
||||||
|
|
||||||
```
|
```
|
||||||
python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
|
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
|
||||||
```
|
```
|
||||||
|
|
||||||
Run `python llama.py --help` for more details.
|
Run `python llama.py --help` for more details.
|
||||||
|
|
||||||
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
||||||
[^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/)
|
[^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/)
|
||||||
|
[^3]: For TinyLlama refer to the [gihub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)
|
||||||
|
147
llama/convert.py
147
llama/convert.py
@ -3,45 +3,35 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import glob
|
import glob
|
||||||
from pathlib import Path
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
|
||||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
|
||||||
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
|
|
||||||
|
|
||||||
|
def llama(model_path):
|
||||||
|
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||||
|
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||||
|
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
|
||||||
|
|
||||||
def shard_key(k):
|
def shard_key(k):
|
||||||
keys = k.split(".")
|
keys = k.split(".")
|
||||||
if len(keys) < 2:
|
if len(keys) < 2:
|
||||||
return None
|
return None
|
||||||
return keys[-2]
|
return keys[-2]
|
||||||
|
|
||||||
|
def unshard(k, v):
|
||||||
|
wn = shard_key(k)
|
||||||
|
if wn not in SHARD_WEIGHTS:
|
||||||
|
return v
|
||||||
|
elif wn in SHARD_FIRST:
|
||||||
|
axis = 0
|
||||||
|
elif wn in SHARD_SECOND:
|
||||||
|
axis = 1
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid weight name")
|
||||||
|
return np.concatenate(v, axis=axis)
|
||||||
|
|
||||||
def unshard(k, v):
|
|
||||||
wn = shard_key(k)
|
|
||||||
if wn not in SHARD_WEIGHTS:
|
|
||||||
return v
|
|
||||||
elif wn in SHARD_FIRST:
|
|
||||||
axis = 0
|
|
||||||
elif wn in SHARD_SECOND:
|
|
||||||
axis = 1
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid weight name")
|
|
||||||
return np.concatenate(v, axis=axis)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
help="Path to the Torch model. The MLX weights will also be saved there.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
model_path = Path(args.model_path)
|
|
||||||
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
|
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
|
||||||
weights = collections.defaultdict(list)
|
weights = collections.defaultdict(list)
|
||||||
for wf in torch_files:
|
for wf in torch_files:
|
||||||
@ -53,7 +43,96 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
weights[k] = v
|
weights[k] = v
|
||||||
|
|
||||||
out_file = str(model_path / "weights.npz")
|
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
weights[k] = unshard(k, v)
|
weights[k] = unshard(k, v)
|
||||||
np.savez(out_file, **weights)
|
return weights, None
|
||||||
|
|
||||||
|
|
||||||
|
def tiny_llama(model_path):
|
||||||
|
try:
|
||||||
|
import transformers
|
||||||
|
except ImportError as e:
|
||||||
|
print("The transformers package must be installed for this model conversion:")
|
||||||
|
print("pip install transformers")
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||||
|
str(model_path)
|
||||||
|
).state_dict()
|
||||||
|
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||||
|
|
||||||
|
# things to change
|
||||||
|
# 1. there's no "model." in the weight names
|
||||||
|
model = {k.replace("model.", ""): v for k, v in model.items()}
|
||||||
|
|
||||||
|
# 2. mlp is called feed_forward
|
||||||
|
model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
|
||||||
|
|
||||||
|
# 3. up_proj, down_proj, gate_proj
|
||||||
|
model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
|
||||||
|
model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
|
||||||
|
model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
|
||||||
|
|
||||||
|
# 4. layernorms
|
||||||
|
model = {
|
||||||
|
k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
|
||||||
|
}
|
||||||
|
model = {
|
||||||
|
k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. lm head
|
||||||
|
model = {k.replace("lm_head", "output"): v for k, v in model.items()}
|
||||||
|
|
||||||
|
# 6. token emb
|
||||||
|
model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
|
||||||
|
|
||||||
|
# 7. attention
|
||||||
|
model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
|
||||||
|
model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
|
||||||
|
model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
|
||||||
|
model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
|
||||||
|
model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
params["dim"] = config.hidden_size
|
||||||
|
params["hidden_dim"] = config.intermediate_size
|
||||||
|
params["n_heads"] = config.num_attention_heads
|
||||||
|
if hasattr(config, "num_key_value_heads"):
|
||||||
|
params["n_kv_heads"] = config.num_key_value_heads
|
||||||
|
params["n_layers"] = config.num_hidden_layers
|
||||||
|
params["vocab_size"] = config.vocab_size
|
||||||
|
params["norm_eps"] = config.rms_norm_eps
|
||||||
|
params["rope_traditional"] = False
|
||||||
|
weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
|
||||||
|
|
||||||
|
return weights, params
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
help="Path to the model. The MLX weights will also be saved there.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
help=(
|
||||||
|
"Name of the model to convert. Use 'llama' for models in the "
|
||||||
|
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
||||||
|
"Coda Llama, and Llama chat."
|
||||||
|
),
|
||||||
|
choices=["tiny_llama", "llama"],
|
||||||
|
default="llama",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model_path = Path(args.model_path)
|
||||||
|
weights, params = globals()[args.model_name](model_path)
|
||||||
|
np.savez(str(model_path / "weights.npz"), **weights)
|
||||||
|
if params is not None:
|
||||||
|
with open(model_path / "params.json", "w") as fid:
|
||||||
|
json.dump(params, fid, indent=4)
|
||||||
|
@ -24,6 +24,7 @@ class ModelArgs:
|
|||||||
norm_eps: float
|
norm_eps: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
rope_theta: float
|
rope_theta: float
|
||||||
|
rope_traditional: bool = True
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
@ -77,7 +78,9 @@ class Attention(nn.Module):
|
|||||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||||
self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
|
self.rope = RoPE(
|
||||||
|
args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -234,7 +237,7 @@ def generate(args):
|
|||||||
|
|
||||||
input("Press enter to start generation")
|
input("Press enter to start generation")
|
||||||
print("------")
|
print("------")
|
||||||
|
print(args.prompt)
|
||||||
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
|
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
|
||||||
skip = 0
|
skip = 0
|
||||||
prompt_processing = None
|
prompt_processing = None
|
||||||
@ -248,7 +251,7 @@ def generate(args):
|
|||||||
mx.eval(token)
|
mx.eval(token)
|
||||||
prompt_processing = toc("Prompt processing", start)
|
prompt_processing = toc("Prompt processing", start)
|
||||||
|
|
||||||
if len(tokens) >= args.num_tokens:
|
if len(tokens) >= args.max_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
elif (len(tokens) % args.write_every) == 0:
|
elif (len(tokens) % args.write_every) == 0:
|
||||||
@ -261,8 +264,7 @@ def generate(args):
|
|||||||
mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
full_gen = toc("Full generation", start)
|
full_gen = toc("Full generation", start)
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
print(s[skip:], end="", flush=True)
|
print(s[skip:], flush=True)
|
||||||
print()
|
|
||||||
print("------")
|
print("------")
|
||||||
print(prompt_processing)
|
print(prompt_processing)
|
||||||
print(full_gen)
|
print(full_gen)
|
||||||
@ -292,7 +294,7 @@ def few_shot_generate(args):
|
|||||||
mx.eval(token)
|
mx.eval(token)
|
||||||
prompt_processing = toc("Prompt processing", start)
|
prompt_processing = toc("Prompt processing", start)
|
||||||
|
|
||||||
if len(tokens) >= args.num_tokens:
|
if len(tokens) >= args.max_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
@ -316,7 +318,8 @@ def few_shot_generate(args):
|
|||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
print(s[skip:], end="", flush=True)
|
print(s[skip:], end="", flush=True)
|
||||||
|
|
||||||
prompt = open(args.prompt).read().strip()
|
print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
|
||||||
|
prompt = open(args.few_shot).read().strip()
|
||||||
while True:
|
while True:
|
||||||
question = input("Ask a question: ")
|
question = input("Ask a question: ")
|
||||||
generate(prompt.replace("{}", question))
|
generate(prompt.replace("{}", question))
|
||||||
@ -354,14 +357,17 @@ if __name__ == "__main__":
|
|||||||
"model", help="Path to the model directory containing the MLX weights"
|
"model", help="Path to the model directory containing the MLX weights"
|
||||||
)
|
)
|
||||||
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
|
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
|
||||||
parser.add_argument("prompt", help="The message to be processed by the model")
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
help="The message to be processed by the model. Ignored when --few-shot is provided.",
|
||||||
|
default="In the beginning the Universe was created.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--few-shot",
|
"--few-shot",
|
||||||
action="store_true",
|
|
||||||
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
|
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
|
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
|
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
|
||||||
|
2
llms/qwen/.gitignore
vendored
Normal file
2
llms/qwen/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
weights.npz
|
||||||
|
config.json
|
41
llms/qwen/README.md
Normal file
41
llms/qwen/README.md
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Qwen
|
||||||
|
|
||||||
|
Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1]
|
||||||
|
The architecture of the Qwen models is similar to Llama except for the bias in
|
||||||
|
the attention layers.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
First download and convert the model with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python convert.py
|
||||||
|
```
|
||||||
|
The script downloads the model from Hugging Face. The default model is
|
||||||
|
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
|
||||||
|
|
||||||
|
The conversion script will make the `weights.npz` and `config.json` files in
|
||||||
|
the working directory.
|
||||||
|
|
||||||
|
## Generate
|
||||||
|
|
||||||
|
To generate text with the default prompt:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python qwen.py
|
||||||
|
```
|
||||||
|
|
||||||
|
If you change the model, make sure to pass the corresponding tokenizer. E.g.,
|
||||||
|
for Qwen 7B use:
|
||||||
|
|
||||||
|
```
|
||||||
|
python qwen.py --tokenizer Qwen/Qwen-7B
|
||||||
|
```
|
||||||
|
|
||||||
|
To see a list of options, run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python qwen.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen).
|
42
llms/qwen/convert.py
Normal file
42
llms/qwen/convert.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import argparse
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def replace_key(key: str) -> str:
|
||||||
|
if key.startswith("transformer."):
|
||||||
|
# remove transformer prefix
|
||||||
|
key = key.replace("transformer.", "")
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def convert(model_path: str = "Qwen/Qwen-1_8B"):
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path, trust_remote_code=True, torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||||
|
np.savez("weights.npz", **weights)
|
||||||
|
|
||||||
|
# write config
|
||||||
|
config = model.config
|
||||||
|
config_dict = config.to_dict()
|
||||||
|
with open("config.json", "w") as f:
|
||||||
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
help="The huggingface model to be converted",
|
||||||
|
default="Qwen/Qwen-1_8B",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
convert(args.model)
|
269
llms/qwen/qwen.py
Normal file
269
llms/qwen/qwen.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_unflatten
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs:
|
||||||
|
hidden_size: int = 2048
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_hidden_layers: int = 24
|
||||||
|
kv_channels: int = 128
|
||||||
|
max_position_embeddings: int = 8192
|
||||||
|
layer_norm_epsilon: float = 1e-6
|
||||||
|
intermediate_size: int = 11008
|
||||||
|
no_bias: bool = True
|
||||||
|
vocab_size: int = 151936
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = mx.ones((dims,))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||||
|
return self.weight * output
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
hidden_size = args.hidden_size
|
||||||
|
self.num_attention_heads = args.num_attention_heads
|
||||||
|
|
||||||
|
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
|
||||||
|
|
||||||
|
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
|
||||||
|
|
||||||
|
proj_size = args.kv_channels * self.num_attention_heads
|
||||||
|
|
||||||
|
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
|
||||||
|
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
|
||||||
|
|
||||||
|
self.scale = hidden_size_per_attention_head**-0.5
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
qkv = self.c_attn(x)
|
||||||
|
|
||||||
|
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||||
|
|
||||||
|
B, L, _ = q.shape
|
||||||
|
|
||||||
|
q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
q = self.rotary_emb(q, offset=k_cache.shape[2])
|
||||||
|
k = self.rotary_emb(k, offset=k_cache.shape[2])
|
||||||
|
k = mx.concatenate([k_cache, k], axis=2)
|
||||||
|
v = mx.concatenate([v_cache, v], axis=2)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q = self.rotary_emb(q)
|
||||||
|
k = self.rotary_emb(k)
|
||||||
|
|
||||||
|
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||||
|
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
return self.c_proj(v_hat), (k, v)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.w1 = nn.Linear(
|
||||||
|
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||||
|
)
|
||||||
|
self.w2 = nn.Linear(
|
||||||
|
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
||||||
|
)
|
||||||
|
self.c_proj = nn.Linear(
|
||||||
|
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
a1 = self.w1(x)
|
||||||
|
a2 = self.w2(x)
|
||||||
|
return self.c_proj(a1 * nn.silu(a2))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
|
self.attn = Attention(args)
|
||||||
|
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
|
self.mlp = MLP(args)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
residual = x
|
||||||
|
x = self.ln_1(x)
|
||||||
|
x, cache = self.attn(x, mask=mask, cache=cache)
|
||||||
|
residual = x + residual
|
||||||
|
x = self.ln_2(residual)
|
||||||
|
x = self.mlp(x)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embed_dim = args.hidden_size
|
||||||
|
|
||||||
|
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||||
|
self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, inputs, mask=None, cache=None):
|
||||||
|
x = self.wte(inputs)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
T = x.shape[1]
|
||||||
|
if T > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||||
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.h)
|
||||||
|
|
||||||
|
for e, layer in enumerate(self.h):
|
||||||
|
x, cache[e] = layer(x, mask, cache[e])
|
||||||
|
|
||||||
|
x = self.ln_f(x[:, T - 1 : T, :])
|
||||||
|
return self.lm_head(x), cache
|
||||||
|
|
||||||
|
|
||||||
|
def generate(prompt: mx.array, model: Qwen, temp: 0.0):
|
||||||
|
def sample(logits):
|
||||||
|
if temp == 0:
|
||||||
|
return mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
logits, cache = model(prompt)
|
||||||
|
y = sample(logits[:, -1, :])
|
||||||
|
yield y
|
||||||
|
|
||||||
|
while True:
|
||||||
|
logits, cache = model(y[:, None], cache=cache)
|
||||||
|
y = sample(logits.squeeze(1))
|
||||||
|
yield y
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
|
||||||
|
):
|
||||||
|
model_args = ModelArgs()
|
||||||
|
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
model_args.vocab_size = config["vocab_size"]
|
||||||
|
model_args.hidden_size = config["hidden_size"]
|
||||||
|
model_args.num_attention_heads = config["num_attention_heads"]
|
||||||
|
model_args.num_hidden_layers = config["num_hidden_layers"]
|
||||||
|
model_args.kv_channels = config["kv_channels"]
|
||||||
|
model_args.max_position_embeddings = config["max_position_embeddings"]
|
||||||
|
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
|
||||||
|
model_args.intermediate_size = config["intermediate_size"]
|
||||||
|
model_args.no_bias = config["no_bias"]
|
||||||
|
|
||||||
|
model = Qwen(model_args)
|
||||||
|
|
||||||
|
weights = mx.load("weights.npz")
|
||||||
|
model.update(tree_unflatten(list(weights.items())))
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
|
||||||
|
)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Qwen inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer",
|
||||||
|
help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
|
||||||
|
default="Qwen/Qwen-1_8B",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
help="The message to be processed by the model",
|
||||||
|
# The example from the official huggingface repo of Qwen
|
||||||
|
default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temp",
|
||||||
|
help="The sampling temperature.",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
model, tokenizer = load_model(args.tokenizer)
|
||||||
|
|
||||||
|
prompt = tokenizer(
|
||||||
|
args.prompt,
|
||||||
|
return_tensors="np",
|
||||||
|
return_attention_mask=False,
|
||||||
|
)["input_ids"]
|
||||||
|
|
||||||
|
prompt = mx.array(prompt)
|
||||||
|
|
||||||
|
print(args.prompt, end="", flush=True)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
|
if (len(tokens) % 10) == 0:
|
||||||
|
mx.eval(tokens)
|
||||||
|
eos_index = next(
|
||||||
|
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if eos_index is not None:
|
||||||
|
tokens = tokens[:eos_index]
|
||||||
|
|
||||||
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
|
print(s, end="", flush=True)
|
||||||
|
tokens = []
|
||||||
|
if eos_index is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
mx.eval(tokens)
|
||||||
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
|
print(s, flush=True)
|
7
llms/qwen/requirements.txt
Normal file
7
llms/qwen/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
einops
|
||||||
|
mlx
|
||||||
|
numpy
|
||||||
|
transformers>=4.35
|
||||||
|
transformers_stream_generator>=0.0.4
|
||||||
|
torch
|
||||||
|
tiktoken
|
@ -32,21 +32,30 @@ if __name__ == "__main__":
|
|||||||
os.makedirs(args.mlx_model)
|
os.makedirs(args.mlx_model)
|
||||||
mlx_path = Path(args.mlx_model)
|
mlx_path = Path(args.mlx_model)
|
||||||
|
|
||||||
|
# Copy the tokenizer
|
||||||
|
tokenizer_path = torch_path / "tokenizer.model"
|
||||||
|
if not tokenizer_path.exists():
|
||||||
|
print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
|
||||||
|
exit(0)
|
||||||
|
shutil.copyfile(
|
||||||
|
str(tokenizer_path),
|
||||||
|
str(mlx_path / "tokenizer.model"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy the model weights
|
||||||
state = torch.load(str(torch_path / "consolidated.00.pth"))
|
state = torch.load(str(torch_path / "consolidated.00.pth"))
|
||||||
np.savez(
|
np.savez(
|
||||||
str(mlx_path / "weights.npz"),
|
str(mlx_path / "weights.npz"),
|
||||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
|
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
|
||||||
)
|
|
||||||
|
|
||||||
# Copy the tokenizer
|
|
||||||
shutil.copyfile(
|
|
||||||
str(torch_path / "tokenizer.model"),
|
|
||||||
str(mlx_path / "tokenizer.model"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copy the params
|
# Copy the params
|
||||||
with open(torch_path / "params.json", "r") as f:
|
with open(torch_path / "params.json", "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
|
unused = ["multiple_of"]
|
||||||
|
for k in unused:
|
||||||
|
if k in config:
|
||||||
|
config.pop(k)
|
||||||
n_heads = config["n_heads"]
|
n_heads = config["n_heads"]
|
||||||
if "sliding_window" in config:
|
if "sliding_window" in config:
|
||||||
config.pop("sliding_window")
|
config.pop("sliding_window")
|
||||||
@ -55,6 +64,6 @@ if __name__ == "__main__":
|
|||||||
if "head_dim" not in config:
|
if "head_dim" not in config:
|
||||||
config["head_dim"] = config["dim"] // n_heads
|
config["head_dim"] = config["dim"] // n_heads
|
||||||
if "hidden_dim" not in config:
|
if "hidden_dim" not in config:
|
||||||
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
|
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
|
||||||
with open(mlx_path / "params.json", "w") as outfile:
|
with open(mlx_path / "params.json", "w") as outfile:
|
||||||
json.dump(config, outfile)
|
json.dump(config, outfile, indent=4)
|
||||||
|
@ -332,9 +332,9 @@ def load_model(folder: str, dtype=mx.float16):
|
|||||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||||
with open(model_path / "params.json", "r") as f:
|
with open(model_path / "params.json", "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
model_args = ModelArgs(**config)
|
|
||||||
if config.get("vocab_size", -1) < 0:
|
if config.get("vocab_size", -1) < 0:
|
||||||
config["vocab_size"] = tokenizer.vocab_size
|
config["vocab_size"] = tokenizer.vocab_size
|
||||||
|
model_args = ModelArgs(**config)
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
weights = mx.load(str(model_path / "weights.npz"))
|
||||||
weights = tree_unflatten(list(weights.items()))
|
weights = tree_unflatten(list(weights.items()))
|
||||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||||
|
1
t5/.gitignore
vendored
Normal file
1
t5/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
*.npz
|
53
t5/README.md
Normal file
53
t5/README.md
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# T5
|
||||||
|
|
||||||
|
The T5 models are encoder-decoder models pre-trained on a mixture of
|
||||||
|
unsupervised and supervised tasks.[^1] These models work well on a variety of
|
||||||
|
tasks by prepending task-specific prefixes to the input, e.g.:
|
||||||
|
`translate English to German: …`, `summarize: ….`, etc.
|
||||||
|
|
||||||
|
This example also supports the FLAN-T5 models variants.[^2]
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Download and convert the model:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python convert.py --model <model>
|
||||||
|
```
|
||||||
|
|
||||||
|
This will make the `<model>.npz` file which MLX can read.
|
||||||
|
|
||||||
|
The `<model>` can be any of the following:
|
||||||
|
|
||||||
|
| Model Name | Model Size |
|
||||||
|
| ---------- | ----------
|
||||||
|
| t5-small | 60 million |
|
||||||
|
| t5-base | 220 million |
|
||||||
|
| t5-large | 770 million |
|
||||||
|
| t5-3b | 3 billion |
|
||||||
|
| t5-11b | 11 billion |
|
||||||
|
|
||||||
|
The FLAN variants can be specified with `google/flan-t5-small`,
|
||||||
|
`google/flan-t5-base`, etc. See the [Hugging Face
|
||||||
|
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||||
|
complete list of models.
|
||||||
|
|
||||||
|
## Generate
|
||||||
|
|
||||||
|
Generate text with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python t5.py --model t5-small --prompt "translate English to German: A tasty apple"
|
||||||
|
```
|
||||||
|
|
||||||
|
This should give the output: `Ein leckerer Apfel`
|
||||||
|
|
||||||
|
To see a list of options run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python t5.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
|
||||||
|
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
|
||||||
|
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
|
77
t5/convert.py
Normal file
77
t5/convert.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
from transformers import T5ForConditionalGeneration
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
SHARED_REPLACEMENT_PATTERNS = [
|
||||||
|
(".block.", ".layers."),
|
||||||
|
(".k.", ".key_proj."),
|
||||||
|
(".o.", ".out_proj."),
|
||||||
|
(".q.", ".query_proj."),
|
||||||
|
(".v.", ".value_proj."),
|
||||||
|
("shared.", "wte."),
|
||||||
|
("lm_head.", "lm_head.linear."),
|
||||||
|
(".layer.0.layer_norm.", ".ln1."),
|
||||||
|
(".layer.1.layer_norm.", ".ln2."),
|
||||||
|
(".layer.2.layer_norm.", ".ln3."),
|
||||||
|
(".final_layer_norm.", ".ln."),
|
||||||
|
(
|
||||||
|
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||||
|
"relative_attention_bias.embeddings.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
ENCODER_REPLACEMENT_PATTERNS = [
|
||||||
|
(".layer.0.SelfAttention.", ".attention."),
|
||||||
|
(".layer.1.DenseReluDense.", ".dense."),
|
||||||
|
]
|
||||||
|
|
||||||
|
DECODER_REPLACEMENT_PATTERNS = [
|
||||||
|
(".layer.0.SelfAttention.", ".self_attention."),
|
||||||
|
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||||
|
(".layer.2.DenseReluDense.", ".dense."),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_key(key: str) -> str:
|
||||||
|
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||||
|
key = key.replace(old, new)
|
||||||
|
if key.startswith("encoder."):
|
||||||
|
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
||||||
|
key = key.replace(old, new)
|
||||||
|
elif key.startswith("decoder."):
|
||||||
|
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
||||||
|
key = key.replace(old, new)
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def convert(model_name, dtype):
|
||||||
|
dtype = getattr(np, dtype)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||||
|
weights = {
|
||||||
|
replace_key(k): v.numpy().astype(dtype)
|
||||||
|
for k, v in model.state_dict().items()
|
||||||
|
}
|
||||||
|
file_name = model_name.replace("/", "-")
|
||||||
|
print(f"Saving weights to {file_name}.npz")
|
||||||
|
np.savez(f"{file_name}.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
help="Name of the T5 model.",
|
||||||
|
default="t5-small",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
help="The model data type.",
|
||||||
|
type=str,
|
||||||
|
choices=["float16", "float32"],
|
||||||
|
default="float32",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert(args.model, args.dtype)
|
54
t5/hf_t5.py
Normal file
54
t5/hf_t5.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def embed(t5_model: str):
|
||||||
|
batch = [
|
||||||
|
"translate English to German: That is good.",
|
||||||
|
"This is an example of T5 working on MLX.",
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
||||||
|
torch_model = T5EncoderModel.from_pretrained(t5_model)
|
||||||
|
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
||||||
|
torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
|
||||||
|
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
||||||
|
|
||||||
|
print("\n TF BERT:")
|
||||||
|
for input_str, embedding in list(zip(batch, torch_output)):
|
||||||
|
print("Input:", input_str)
|
||||||
|
print(embedding)
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def generate(t5_model: str):
|
||||||
|
prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
||||||
|
torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
|
||||||
|
torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
|
||||||
|
outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
|
||||||
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run the T5 model using Hugging Face Transformers."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--encode-only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run the encoder and print the embeddings.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
default="t5-small",
|
||||||
|
help="The huggingface name of the T5 model to save.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.encode_only:
|
||||||
|
embed(args.model)
|
||||||
|
else:
|
||||||
|
generate(args.model)
|
||||||
|
|
3
t5/requirements.txt
Normal file
3
t5/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
mlx
|
||||||
|
numpy
|
||||||
|
transformers
|
469
t5/t5.py
Normal file
469
t5/t5.py
Normal file
@ -0,0 +1,469 @@
|
|||||||
|
import argparse
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
from time import perf_counter_ns
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_unflatten, tree_map
|
||||||
|
from transformers import T5Config, T5Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def _relative_position_bucket(
|
||||||
|
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adapted from HF Tensorflow:
|
||||||
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||||
|
|
||||||
|
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||||
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||||
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||||
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||||
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||||
|
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relative_position: an int32 Tensor
|
||||||
|
bidirectional: a boolean - whether the attention is bidirectional
|
||||||
|
num_buckets: an integer
|
||||||
|
max_distance: an integer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||||
|
"""
|
||||||
|
relative_buckets = 0
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets //= 2
|
||||||
|
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||||
|
relative_position = mx.abs(relative_position)
|
||||||
|
else:
|
||||||
|
relative_position = -mx.minimum(
|
||||||
|
relative_position, mx.zeros_like(relative_position)
|
||||||
|
)
|
||||||
|
# now relative_position is in the range [0, inf)
|
||||||
|
|
||||||
|
# half of the buckets are for exact increments in positions
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_position < max_exact
|
||||||
|
|
||||||
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
|
||||||
|
relative_position_if_large = max_exact + (
|
||||||
|
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
|
||||||
|
).astype(mx.int16)
|
||||||
|
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
|
||||||
|
relative_buckets += mx.where(
|
||||||
|
is_small, relative_position, relative_position_if_large
|
||||||
|
)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePositionBias(nn.Module):
|
||||||
|
def __init__(self, config: T5Config, bidirectional: bool):
|
||||||
|
self.bidirectional = bidirectional
|
||||||
|
self.num_buckets = config.relative_attention_num_buckets
|
||||||
|
self.max_distance = config.relative_attention_max_distance
|
||||||
|
self.n_heads = config.num_heads
|
||||||
|
self.embeddings = nn.Embedding(
|
||||||
|
config.relative_attention_num_buckets, config.num_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||||
|
"""Compute binned relative position bias"""
|
||||||
|
context_position = mx.arange(offset, query_length)[:, None]
|
||||||
|
memory_position = mx.arange(key_length)[None, :]
|
||||||
|
|
||||||
|
# shape (query_length, key_length)
|
||||||
|
relative_position = memory_position - context_position
|
||||||
|
relative_position_bucket = _relative_position_bucket(
|
||||||
|
relative_position,
|
||||||
|
bidirectional=self.bidirectional,
|
||||||
|
num_buckets=self.num_buckets,
|
||||||
|
max_distance=self.max_distance,
|
||||||
|
)
|
||||||
|
|
||||||
|
# shape (query_length, key_length, num_heads)
|
||||||
|
values = self.embeddings(relative_position_bucket)
|
||||||
|
|
||||||
|
# shape (num_heads, query_length, key_length)
|
||||||
|
return values.transpose(2, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = config.d_kv * config.num_heads
|
||||||
|
self.num_heads = config.num_heads
|
||||||
|
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||||
|
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||||
|
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||||
|
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
queries: mx.array,
|
||||||
|
keys: mx.array,
|
||||||
|
values: mx.array,
|
||||||
|
mask: Optional[mx.array],
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||||
|
queries = self.query_proj(queries)
|
||||||
|
keys = self.key_proj(keys)
|
||||||
|
values = self.value_proj(values)
|
||||||
|
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, _ = queries.shape
|
||||||
|
_, S, _ = keys.shape
|
||||||
|
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||||
|
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
keys = mx.concatenate([key_cache, keys], axis=3)
|
||||||
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
|
||||||
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
|
queries = queries
|
||||||
|
scores = queries @ keys
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||||
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = mx.ones((dims,))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
t = x.dtype
|
||||||
|
output = self._norm(x).astype(t)
|
||||||
|
return self.weight * output
|
||||||
|
|
||||||
|
|
||||||
|
class DenseActivation(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
|
self.gated = config.feed_forward_proj.startswith("gated")
|
||||||
|
if self.gated:
|
||||||
|
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
|
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
|
else:
|
||||||
|
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
|
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||||
|
activation = config.feed_forward_proj.removeprefix("gated-")
|
||||||
|
if activation == "relu":
|
||||||
|
self.act = nn.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
self.act = nn.gelu
|
||||||
|
elif activation == "silu":
|
||||||
|
self.act = nn.silu
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation: {activation}")
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
if self.gated:
|
||||||
|
hidden_act = self.act(self.wi_0(x))
|
||||||
|
hidden_linear = self.wi_1(x)
|
||||||
|
x = hidden_act * hidden_linear
|
||||||
|
else:
|
||||||
|
x = self.act(self.wi(x))
|
||||||
|
return self.wo(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = MultiHeadAttention(config)
|
||||||
|
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.dense = DenseActivation(config)
|
||||||
|
|
||||||
|
def __call__(self, x, mask):
|
||||||
|
y = self.ln1(x)
|
||||||
|
y, _ = self.attention(y, y, y, mask=mask)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.ln2(x)
|
||||||
|
y = self.dense(y)
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = [
|
||||||
|
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||||
|
]
|
||||||
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, mask=pos_bias)
|
||||||
|
return self.ln(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attention = MultiHeadAttention(config)
|
||||||
|
self.cross_attention = MultiHeadAttention(config)
|
||||||
|
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.dense = DenseActivation(config)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
memory: mx.array,
|
||||||
|
mask: mx.array,
|
||||||
|
memory_mask: mx.array,
|
||||||
|
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||||
|
):
|
||||||
|
y = self.ln1(x)
|
||||||
|
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.ln2(x)
|
||||||
|
y, _ = self.cross_attention(y, memory, memory, memory_mask)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
y = self.ln3(x)
|
||||||
|
y = self.dense(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = [
|
||||||
|
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||||
|
]
|
||||||
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||||
|
|
||||||
|
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||||
|
if cache is not None:
|
||||||
|
offset = cache[0][0].shape[3]
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
T = offset + x.shape[1]
|
||||||
|
pos_bias = self.relative_attention_bias(T, T, offset=offset)
|
||||||
|
if mask is not None:
|
||||||
|
mask += pos_bias
|
||||||
|
else:
|
||||||
|
mask = pos_bias
|
||||||
|
|
||||||
|
for e, layer in enumerate(self.layers):
|
||||||
|
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
|
||||||
|
x = self.ln(x)
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
class OutputHead(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, inputs):
|
||||||
|
return self.linear(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
class T5(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
|
self.encoder = TransformerEncoder(config)
|
||||||
|
self.decoder = TransformerDecoder(config)
|
||||||
|
self.tie_word_embeddings = config.tie_word_embeddings
|
||||||
|
if not self.tie_word_embeddings:
|
||||||
|
self.lm_head = OutputHead(config)
|
||||||
|
self.model_dim = config.d_model
|
||||||
|
|
||||||
|
def encode(self, inputs: mx.array):
|
||||||
|
return self.encoder(self.wte(inputs))
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
memory: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
inputs = self.wte(inputs)
|
||||||
|
T = inputs.shape[1]
|
||||||
|
if T > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||||
|
mask = mask.astype(inputs.dtype)
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
y, cache = self.decoder(
|
||||||
|
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||||
|
)
|
||||||
|
if not self.tie_word_embeddings:
|
||||||
|
y *= self.model_dim**-0.5
|
||||||
|
y = self.lm_head(y)
|
||||||
|
else:
|
||||||
|
y = y @ self.wte.weight.T
|
||||||
|
return y, cache
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
decoder_inputs: mx.array,
|
||||||
|
):
|
||||||
|
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
def __init__(self, model_name: str, config: T5Config):
|
||||||
|
self._decoder_start_id = config.decoder_start_token_id
|
||||||
|
self._tokenizer = T5Tokenizer.from_pretrained(
|
||||||
|
args.model,
|
||||||
|
legacy=False,
|
||||||
|
model_max_length=getattr(config, 'n_positions', 512)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_id(self) -> int:
|
||||||
|
return self._tokenizer.eos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decoder_start_id(self) -> int:
|
||||||
|
return self._decoder_start_id
|
||||||
|
|
||||||
|
def encode(self, s: str) -> mx.array:
|
||||||
|
return mx.array(
|
||||||
|
self._tokenizer(
|
||||||
|
s,
|
||||||
|
return_tensors="np",
|
||||||
|
return_attention_mask=False,
|
||||||
|
)["input_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||||
|
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||||
|
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
||||||
|
def sample(logits):
|
||||||
|
if temp == 0:
|
||||||
|
return mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
prompt = tokenizer.encode(prompt)
|
||||||
|
decoder_inputs = mx.array([tokenizer.decoder_start_id])
|
||||||
|
memory = model.encode(prompt)
|
||||||
|
cache = None
|
||||||
|
y = decoder_inputs
|
||||||
|
while True:
|
||||||
|
logits, cache = model.decode(y[None], memory, cache=cache)
|
||||||
|
y = sample(logits[:, -1, :])
|
||||||
|
yield y.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_name: str, dtype: str = "float16"):
|
||||||
|
config = T5Config.from_pretrained(args.model)
|
||||||
|
dtype = getattr(mx, dtype)
|
||||||
|
model = T5(config)
|
||||||
|
file_name = model_name.replace("/", "-")
|
||||||
|
weights = mx.load(f"{file_name}.npz")
|
||||||
|
weights = tree_unflatten(list(weights.items()))
|
||||||
|
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||||
|
model.update(weights)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model, Tokenizer(args.model, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="T5 Inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
help="Name of the T5 model.",
|
||||||
|
default="t5-small",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
help="",
|
||||||
|
default="translate English to German: That is good.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--encode-only",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Whether to decode or not. If true, will output last layer of encoder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temp",
|
||||||
|
help="The sampling temperature.",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
help="The model data type.",
|
||||||
|
type=str,
|
||||||
|
choices=["float16", "bfloat16", "float32"],
|
||||||
|
default="bfloat16",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
model, tokenizer = load_model(args.model, args.dtype)
|
||||||
|
|
||||||
|
if args.encode_only:
|
||||||
|
print("[INFO] Encoding with T5...", flush=True)
|
||||||
|
print(args.prompt, flush=True)
|
||||||
|
encoder_output = model.encode(tokenizer.encode(args.prompt))
|
||||||
|
print(encoder_output, flush=True)
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
print("[INFO] Generating with T5...", flush=True)
|
||||||
|
print("Input: ", args.prompt, flush=True)
|
||||||
|
|
||||||
|
start = perf_counter_ns()
|
||||||
|
for token, n_tokens in zip(
|
||||||
|
generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
|
||||||
|
):
|
||||||
|
if token.item() == tokenizer.eos_id:
|
||||||
|
break
|
||||||
|
print(
|
||||||
|
tokenizer.decode([token.item()], with_sep=n_tokens > 0),
|
||||||
|
end="",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_tokens += 1
|
||||||
|
end = perf_counter_ns()
|
||||||
|
elapsed = (end - start) / 1.0e9
|
||||||
|
print()
|
||||||
|
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
|
Loading…
Reference in New Issue
Block a user