diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md new file mode 100644 index 00000000..3c368003 --- /dev/null +++ b/ACKNOWLEDGMENTS.md @@ -0,0 +1,10 @@ +# Individual Contributors + +If you wish to be acknowledged for your contributions, please list your name +with a short description of your contribution(s) below. For example: + +- Jane Smith: Added the `foo` example. + +MLX Examples was developed with contributions from the following individuals: + +- Juarez Bochi: Added support for T5 models. diff --git a/README.md b/README.md index 7988e37a..5ea53d25 100644 --- a/README.md +++ b/README.md @@ -18,5 +18,24 @@ Some more useful examples include: ## Contributing -Check out the [contribution guidelines](CONTRIBUTING.md) for more information -on contributing to this repo. +We are grateful for all of [our +contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute +to MLX Examples and wish to be acknowledged, please add your name to to the list in your +pull request. + +## Citing MLX Examples + +The MLX software suite was initially developed with equal contribution by Awni +Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find +MLX Examples useful in your research and wish to cite it, please use the following +BibTex entry: + +``` +@software{mlx2023, + author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert}, + title = {{MLX}: Efficient and flexible machine learning on Apple silicon}, + url = {https://github.com/ml-explore}, + version = {0.0}, + year = {2023}, +} +``` diff --git a/llama/README.md b/llama/README.md index 220d1b16..39c0267c 100644 --- a/llama/README.md +++ b/llama/README.md @@ -3,8 +3,8 @@ An example of generating text with Llama (1 or 2) using MLX. Llama is a set of open source language models from Meta AI Research[^1][^2] -ranging from 7B to 70B parameters. This example also supports Llama Chat and -Code Llama. +ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat +and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3] ### Setup @@ -25,10 +25,19 @@ Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging Face and skip the conversion step. +You can download the TinyLlama models directly from [Hugging +Face](https://huggingface.co/TinyLlama). + Convert the weights with: ``` -python convert.py --model_path +python convert.py --model-path +``` + +For TinyLlama use + +``` +python convert.py --model-path --model-name tiny_llama ``` The conversion script will save the converted weights in the same location. @@ -39,10 +48,11 @@ Once you've converted the weights to MLX format, you can interact with the LlaMA model: ``` -python llama.py "hello" +python llama.py --prompt "hello" ``` Run `python llama.py --help` for more details. [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details. [^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/) +[^3]: For TinyLlama refer to the [gihub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file) diff --git a/llama/convert.py b/llama/convert.py index 89ce8a36..957f3d22 100644 --- a/llama/convert.py +++ b/llama/convert.py @@ -3,45 +3,35 @@ import argparse import collections import glob -from pathlib import Path - +import json import numpy as np +from pathlib import Path import torch -SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"] -SHARD_SECOND = ["tok_embeddings", "wo", "w2"] -SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND) +def llama(model_path): + SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"] + SHARD_SECOND = ["tok_embeddings", "wo", "w2"] + SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND) -def shard_key(k): - keys = k.split(".") - if len(keys) < 2: - return None - return keys[-2] + def shard_key(k): + keys = k.split(".") + if len(keys) < 2: + return None + return keys[-2] + def unshard(k, v): + wn = shard_key(k) + if wn not in SHARD_WEIGHTS: + return v + elif wn in SHARD_FIRST: + axis = 0 + elif wn in SHARD_SECOND: + axis = 1 + else: + raise ValueError("Invalid weight name") + return np.concatenate(v, axis=axis) -def unshard(k, v): - wn = shard_key(k) - if wn not in SHARD_WEIGHTS: - return v - elif wn in SHARD_FIRST: - axis = 0 - elif wn in SHARD_SECOND: - axis = 1 - else: - raise ValueError("Invalid weight name") - return np.concatenate(v, axis=axis) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") - parser.add_argument( - "--model_path", - help="Path to the Torch model. The MLX weights will also be saved there.", - ) - args = parser.parse_args() - - model_path = Path(args.model_path) torch_files = glob.glob(str(model_path / "consolidated.*.pth")) weights = collections.defaultdict(list) for wf in torch_files: @@ -53,7 +43,96 @@ if __name__ == "__main__": else: weights[k] = v - out_file = str(model_path / "weights.npz") for k, v in weights.items(): weights[k] = unshard(k, v) - np.savez(out_file, **weights) + return weights, None + + +def tiny_llama(model_path): + try: + import transformers + except ImportError as e: + print("The transformers package must be installed for this model conversion:") + print("pip install transformers") + import sys + + sys.exit(0) + + model = transformers.AutoModelForCausalLM.from_pretrained( + str(model_path) + ).state_dict() + config = transformers.AutoConfig.from_pretrained(model_path) + + # things to change + # 1. there's no "model." in the weight names + model = {k.replace("model.", ""): v for k, v in model.items()} + + # 2. mlp is called feed_forward + model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()} + + # 3. up_proj, down_proj, gate_proj + model = {k.replace("down_proj", "w2"): v for k, v in model.items()} + model = {k.replace("up_proj", "w3"): v for k, v in model.items()} + model = {k.replace("gate_proj", "w1"): v for k, v in model.items()} + + # 4. layernorms + model = { + k.replace("input_layernorm", "attention_norm"): v for k, v in model.items() + } + model = { + k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items() + } + + # 5. lm head + model = {k.replace("lm_head", "output"): v for k, v in model.items()} + + # 6. token emb + model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()} + + # 7. attention + model = {k.replace("self_attn", "attention"): v for k, v in model.items()} + model = {k.replace("q_proj", "wq"): v for k, v in model.items()} + model = {k.replace("k_proj", "wk"): v for k, v in model.items()} + model = {k.replace("v_proj", "wv"): v for k, v in model.items()} + model = {k.replace("o_proj", "wo"): v for k, v in model.items()} + + params = {} + params["dim"] = config.hidden_size + params["hidden_dim"] = config.intermediate_size + params["n_heads"] = config.num_attention_heads + if hasattr(config, "num_key_value_heads"): + params["n_kv_heads"] = config.num_key_value_heads + params["n_layers"] = config.num_hidden_layers + params["vocab_size"] = config.vocab_size + params["norm_eps"] = config.rms_norm_eps + params["rope_traditional"] = False + weights = {k: v.to(torch.float16).numpy() for k, v in model.items()} + + return weights, params + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") + parser.add_argument( + "--model-path", + help="Path to the model. The MLX weights will also be saved there.", + ) + parser.add_argument( + "--model-name", + help=( + "Name of the model to convert. Use 'llama' for models in the " + "Llama family distributed by Meta including Llama 1, Llama 2, " + "Coda Llama, and Llama chat." + ), + choices=["tiny_llama", "llama"], + default="llama", + ) + + args = parser.parse_args() + + model_path = Path(args.model_path) + weights, params = globals()[args.model_name](model_path) + np.savez(str(model_path / "weights.npz"), **weights) + if params is not None: + with open(model_path / "params.json", "w") as fid: + json.dump(params, fid, indent=4) diff --git a/llama/llama.py b/llama/llama.py index 5f169de4..2c1f4d16 100644 --- a/llama/llama.py +++ b/llama/llama.py @@ -24,6 +24,7 @@ class ModelArgs: norm_eps: float vocab_size: int rope_theta: float + rope_traditional: bool = True class RMSNorm(nn.Module): @@ -77,7 +78,9 @@ class Attention(nn.Module): self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) - self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta) + self.rope = RoPE( + args.head_dim, traditional=args.rope_traditional, base=args.rope_theta + ) def __call__( self, @@ -234,7 +237,7 @@ def generate(args): input("Press enter to start generation") print("------") - + print(args.prompt) x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) skip = 0 prompt_processing = None @@ -248,7 +251,7 @@ def generate(args): mx.eval(token) prompt_processing = toc("Prompt processing", start) - if len(tokens) >= args.num_tokens: + if len(tokens) >= args.max_tokens: break elif (len(tokens) % args.write_every) == 0: @@ -261,8 +264,7 @@ def generate(args): mx.eval(tokens) full_gen = toc("Full generation", start) s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - print() + print(s[skip:], flush=True) print("------") print(prompt_processing) print(full_gen) @@ -292,7 +294,7 @@ def few_shot_generate(args): mx.eval(token) prompt_processing = toc("Prompt processing", start) - if len(tokens) >= args.num_tokens: + if len(tokens) >= args.max_tokens: break mx.eval(tokens) @@ -316,7 +318,8 @@ def few_shot_generate(args): s = tokenizer.decode([t.item() for t in tokens]) print(s[skip:], end="", flush=True) - prompt = open(args.prompt).read().strip() + print("[INFO] Loading few-shot examples from: {}".format(args.few_shot)) + prompt = open(args.few_shot).read().strip() while True: question = input("Ask a question: ") generate(prompt.replace("{}", question)) @@ -354,14 +357,17 @@ if __name__ == "__main__": "model", help="Path to the model directory containing the MLX weights" ) parser.add_argument("tokenizer", help="The sentencepiece tokenizer") - parser.add_argument("prompt", help="The message to be processed by the model") + parser.add_argument( + "--prompt", + help="The message to be processed by the model. Ignored when --few-shot is provided.", + default="In the beginning the Universe was created.", + ) parser.add_argument( "--few-shot", - action="store_true", help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", ) parser.add_argument( - "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate" + "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate" ) parser.add_argument( "--write-every", type=int, default=1, help="After how many tokens to detokenize" diff --git a/llms/qwen/.gitignore b/llms/qwen/.gitignore new file mode 100644 index 00000000..0c68f15d --- /dev/null +++ b/llms/qwen/.gitignore @@ -0,0 +1,2 @@ +weights.npz +config.json diff --git a/llms/qwen/README.md b/llms/qwen/README.md new file mode 100644 index 00000000..f9276098 --- /dev/null +++ b/llms/qwen/README.md @@ -0,0 +1,41 @@ +# Qwen + +Qwen (通义千问) are a family of language models developed by Alibaba Cloud.[^1] +The architecture of the Qwen models is similar to Llama except for the bias in +the attention layers. + +## Setup + +First download and convert the model with: + +```sh +python convert.py +``` +The script downloads the model from Hugging Face. The default model is +`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models. + +The conversion script will make the `weights.npz` and `config.json` files in +the working directory. + +## Generate + +To generate text with the default prompt: + +```sh +python qwen.py +``` + +If you change the model, make sure to pass the corresponding tokenizer. E.g., +for Qwen 7B use: + +``` +python qwen.py --tokenizer Qwen/Qwen-7B +``` + +To see a list of options, run: + +```sh +python qwen.py --help +``` + +[^1]: For more details on the model see the official repo of [Qwen](https://github.com/QwenLM/Qwen) and the [Hugging Face](https://huggingface.co/Qwen). diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py new file mode 100644 index 00000000..50a8d7a8 --- /dev/null +++ b/llms/qwen/convert.py @@ -0,0 +1,42 @@ +import argparse +from transformers import AutoModelForCausalLM +import numpy as np +import torch +import json + + +def replace_key(key: str) -> str: + if key.startswith("transformer."): + # remove transformer prefix + key = key.replace("transformer.", "") + + return key + + +def convert(model_path: str = "Qwen/Qwen-1_8B"): + model = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, torch_dtype=torch.float16 + ) + state_dict = model.state_dict() + weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + np.savez("weights.npz", **weights) + + # write config + config = model.config + config_dict = config.to_dict() + with open("config.json", "w") as f: + json.dump(config_dict, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Qwen model to npz") + + parser.add_argument( + "--model", + help="The huggingface model to be converted", + default="Qwen/Qwen-1_8B", + ) + + args = parser.parse_args() + + convert(args.model) diff --git a/llms/qwen/qwen.py b/llms/qwen/qwen.py new file mode 100644 index 00000000..c490d650 --- /dev/null +++ b/llms/qwen/qwen.py @@ -0,0 +1,269 @@ +import argparse +from dataclasses import dataclass +import json +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from transformers import AutoTokenizer + + +@dataclass +class ModelArgs: + hidden_size: int = 2048 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + kv_channels: int = 128 + max_position_embeddings: int = 8192 + layer_norm_epsilon: float = 1e-6 + intermediate_size: int = 11008 + no_bias: bool = True + vocab_size: int = 151936 + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + hidden_size = args.hidden_size + self.num_attention_heads = args.num_attention_heads + + hidden_size_per_attention_head = hidden_size // self.num_attention_heads + + self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False) + + proj_size = args.kv_channels * self.num_attention_heads + + self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True) + self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias) + + self.scale = hidden_size_per_attention_head**-0.5 + + def __call__(self, x, mask=None, cache=None): + qkv = self.c_attn(x) + + q, k, v = mx.split(qkv, 3, axis=-1) + + B, L, _ = q.shape + + q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + k_cache, v_cache = cache + q = self.rotary_emb(q, offset=k_cache.shape[2]) + k = self.rotary_emb(k, offset=k_cache.shape[2]) + k = mx.concatenate([k_cache, k], axis=2) + v = mx.concatenate([v_cache, v], axis=2) + + else: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + + scores = (q * self.scale) @ k.transpose(0, 1, 3, 2) + + if mask is not None: + scores = scores + mask + + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.c_proj(v_hat), (k, v) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear( + args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias + ) + self.w2 = nn.Linear( + args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias + ) + self.c_proj = nn.Linear( + args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias + ) + + def __call__(self, x): + a1 = self.w1(x) + a2 = self.w2(x) + return self.c_proj(a1 * nn.silu(a2)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.attn = Attention(args) + self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.mlp = MLP(args) + + def __call__(self, x, mask=None, cache=None): + residual = x + x = self.ln_1(x) + x, cache = self.attn(x, mask=mask, cache=cache) + residual = x + residual + x = self.ln_2(residual) + x = self.mlp(x) + x = x + residual + + return x, cache + + +class Qwen(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.embed_dim = args.hidden_size + + self.wte = nn.Embedding(args.vocab_size, args.hidden_size) + self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] + self.ln_f = RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon) + + self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False) + + def __call__(self, inputs, mask=None, cache=None): + x = self.wte(inputs) + + mask = None + T = x.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(x.dtype) + + if cache is None: + cache = [None] * len(self.h) + + for e, layer in enumerate(self.h): + x, cache[e] = layer(x, mask, cache[e]) + + x = self.ln_f(x[:, T - 1 : T, :]) + return self.lm_head(x), cache + + +def generate(prompt: mx.array, model: Qwen, temp: 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = model(prompt) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = model(y[:, None], cache=cache) + y = sample(logits.squeeze(1)) + yield y + + +def load_model( + tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json" +): + model_args = ModelArgs() + + with open(config_path, "r") as f: + config = json.load(f) + model_args.vocab_size = config["vocab_size"] + model_args.hidden_size = config["hidden_size"] + model_args.num_attention_heads = config["num_attention_heads"] + model_args.num_hidden_layers = config["num_hidden_layers"] + model_args.kv_channels = config["kv_channels"] + model_args.max_position_embeddings = config["max_position_embeddings"] + model_args.layer_norm_epsilon = config["layer_norm_epsilon"] + model_args.intermediate_size = config["intermediate_size"] + model_args.no_bias = config["no_bias"] + + model = Qwen(model_args) + + weights = mx.load("weights.npz") + model.update(tree_unflatten(list(weights.items()))) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>" + ) + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Qwen inference script") + parser.add_argument( + "--tokenizer", + help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B", + default="Qwen/Qwen-1_8B", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + # The example from the official huggingface repo of Qwen + default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model(args.tokenizer) + + prompt = tokenizer( + args.prompt, + return_tensors="np", + return_attention_mask=False, + )["input_ids"] + + prompt = mx.array(prompt) + + print(args.prompt, end="", flush=True) + + tokens = [] + for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): + tokens.append(token) + + if (len(tokens) % 10) == 0: + mx.eval(tokens) + eos_index = next( + (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), + None, + ) + + if eos_index is not None: + tokens = tokens[:eos_index] + + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] + if eos_index is not None: + break + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) diff --git a/llms/qwen/requirements.txt b/llms/qwen/requirements.txt new file mode 100644 index 00000000..0ce17aec --- /dev/null +++ b/llms/qwen/requirements.txt @@ -0,0 +1,7 @@ +einops +mlx +numpy +transformers>=4.35 +transformers_stream_generator>=0.0.4 +torch +tiktoken diff --git a/lora/convert.py b/lora/convert.py index 16af7931..3fdb5d42 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -32,21 +32,30 @@ if __name__ == "__main__": os.makedirs(args.mlx_model) mlx_path = Path(args.mlx_model) + # Copy the tokenizer + tokenizer_path = torch_path / "tokenizer.model" + if not tokenizer_path.exists(): + print(f"Make sure there is a file tokenizer.model in {args.torch_model}") + exit(0) + shutil.copyfile( + str(tokenizer_path), + str(mlx_path / "tokenizer.model"), + ) + + # Copy the model weights state = torch.load(str(torch_path / "consolidated.00.pth")) np.savez( str(mlx_path / "weights.npz"), - **{k: v.to(torch.float16).numpy() for k, v in state.items()} - ) - - # Copy the tokenizer - shutil.copyfile( - str(torch_path / "tokenizer.model"), - str(mlx_path / "tokenizer.model"), + **{k: v.to(torch.float16).numpy() for k, v in state.items()}, ) # Copy the params with open(torch_path / "params.json", "r") as f: config = json.loads(f.read()) + unused = ["multiple_of"] + for k in unused: + if k in config: + config.pop(k) n_heads = config["n_heads"] if "sliding_window" in config: config.pop("sliding_window") @@ -55,6 +64,6 @@ if __name__ == "__main__": if "head_dim" not in config: config["head_dim"] = config["dim"] // n_heads if "hidden_dim" not in config: - config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape + config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0] with open(mlx_path / "params.json", "w") as outfile: - json.dump(config, outfile) + json.dump(config, outfile, indent=4) diff --git a/lora/lora.py b/lora/lora.py index 2e0fa0a1..e1412da3 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -332,9 +332,9 @@ def load_model(folder: str, dtype=mx.float16): tokenizer = Tokenizer(str(model_path / "tokenizer.model")) with open(model_path / "params.json", "r") as f: config = json.loads(f.read()) - model_args = ModelArgs(**config) if config.get("vocab_size", -1) < 0: config["vocab_size"] = tokenizer.vocab_size + model_args = ModelArgs(**config) weights = mx.load(str(model_path / "weights.npz")) weights = tree_unflatten(list(weights.items())) weights = tree_map(lambda p: p.astype(dtype), weights) diff --git a/t5/.gitignore b/t5/.gitignore new file mode 100644 index 00000000..ded9aef9 --- /dev/null +++ b/t5/.gitignore @@ -0,0 +1 @@ +*.npz diff --git a/t5/README.md b/t5/README.md new file mode 100644 index 00000000..a0cc861b --- /dev/null +++ b/t5/README.md @@ -0,0 +1,53 @@ +# T5 + +The T5 models are encoder-decoder models pre-trained on a mixture of +unsupervised and supervised tasks.[^1] These models work well on a variety of +tasks by prepending task-specific prefixes to the input, e.g.: +`translate English to German: …`, `summarize: ….`, etc. + +This example also supports the FLAN-T5 models variants.[^2] + +## Setup + +Download and convert the model: + +```sh +python convert.py --model +``` + +This will make the `.npz` file which MLX can read. + +The `` can be any of the following: + +| Model Name | Model Size | +| ---------- | ---------- +| t5-small | 60 million | +| t5-base | 220 million | +| t5-large | 770 million | +| t5-3b | 3 billion | +| t5-11b | 11 billion | + +The FLAN variants can be specified with `google/flan-t5-small`, +`google/flan-t5-base`, etc. See the [Hugging Face +page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a +complete list of models. + +## Generate + +Generate text with: + +```sh +python t5.py --model t5-small --prompt "translate English to German: A tasty apple" +``` + +This should give the output: `Ein leckerer Apfel` + +To see a list of options run: + +```sh +python t5.py --help +``` + +[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683) + or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5). +[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416). diff --git a/t5/convert.py b/t5/convert.py new file mode 100644 index 00000000..089d262d --- /dev/null +++ b/t5/convert.py @@ -0,0 +1,77 @@ +from transformers import T5ForConditionalGeneration +import numpy as np + + +SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + +DECODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".self_attention."), + (".layer.1.EncDecAttention.", ".cross_attention."), + (".layer.2.DenseReluDense.", ".dense."), +] + + +def replace_key(key: str) -> str: + for old, new in SHARED_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + if key.startswith("encoder."): + for old, new in ENCODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + elif key.startswith("decoder."): + for old, new in DECODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + return key + + +def convert(model_name, dtype): + dtype = getattr(np, dtype) + model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") + weights = { + replace_key(k): v.numpy().astype(dtype) + for k, v in model.state_dict().items() + } + file_name = model_name.replace("/", "-") + print(f"Saving weights to {file_name}.npz") + np.savez(f"{file_name}.npz", **weights) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Convert T5 weights to MLX") + parser.add_argument( + "--model", + type=str, + help="Name of the T5 model.", + default="t5-small", + ) + parser.add_argument( + "--dtype", + help="The model data type.", + type=str, + choices=["float16", "float32"], + default="float32", + ) + args = parser.parse_args() + convert(args.model, args.dtype) diff --git a/t5/hf_t5.py b/t5/hf_t5.py new file mode 100644 index 00000000..ddd99610 --- /dev/null +++ b/t5/hf_t5.py @@ -0,0 +1,54 @@ +from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer + +import argparse + + +def embed(t5_model: str): + batch = [ + "translate English to German: That is good.", + "This is an example of T5 working on MLX.", + ] + + tokenizer = AutoTokenizer.from_pretrained(t5_model) + torch_model = T5EncoderModel.from_pretrained(t5_model) + torch_tokens = tokenizer(batch, return_tensors="pt", padding=True) + torch_forward = torch_model(**torch_tokens, output_hidden_states=True) + torch_output = torch_forward.last_hidden_state.detach().numpy() + + print("\n TF BERT:") + for input_str, embedding in list(zip(batch, torch_output)): + print("Input:", input_str) + print(embedding) + print() + + +def generate(t5_model: str): + prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast." + tokenizer = AutoTokenizer.from_pretrained(t5_model) + torch_model = T5ForConditionalGeneration.from_pretrained(t5_model) + torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids + outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512) + print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run the T5 model using Hugging Face Transformers." + ) + parser.add_argument( + "--encode-only", + action="store_true", + help="Only run the encoder and print the embeddings.", + default=False, + ) + parser.add_argument( + "--model", + default="t5-small", + help="The huggingface name of the T5 model to save.", + ) + args = parser.parse_args() + if args.encode_only: + embed(args.model) + else: + generate(args.model) + diff --git a/t5/requirements.txt b/t5/requirements.txt new file mode 100644 index 00000000..4a37303a --- /dev/null +++ b/t5/requirements.txt @@ -0,0 +1,3 @@ +mlx +numpy +transformers diff --git a/t5/t5.py b/t5/t5.py new file mode 100644 index 00000000..f80c3cb3 --- /dev/null +++ b/t5/t5.py @@ -0,0 +1,469 @@ +import argparse +from typing import Optional, Tuple, List +from time import perf_counter_ns + +import numpy as np +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten, tree_map +from transformers import T5Config, T5Tokenizer + + +def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 +): + """ + Adapted from HF Tensorflow: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets + relative_position = mx.abs(relative_position) + else: + relative_position = -mx.minimum( + relative_position, mx.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) + relative_position_if_large = max_exact + ( + mx.log(relative_position.astype(mx.float32) / max_exact) * scale + ).astype(mx.int16) + relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += mx.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding( + config.relative_attention_num_buckets, config.num_heads + ) + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = _relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + inner_dim = config.d_kv * config.num_heads + self.num_heads = config.num_heads + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array], + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> [mx.array, Tuple[mx.array, mx.array]]: + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, _ = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + keys = mx.concatenate([key_cache, keys], axis=3) + values = mx.concatenate([value_cache, values], axis=2) + + # Dimensions are [batch x num heads x sequence x hidden dim] + queries = queries + scores = queries @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.out_proj(values_hat), (keys, values) + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + t = x.dtype + output = self._norm(x).astype(t) + return self.weight * output + + +class DenseActivation(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) + else: + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") + + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.attention = MultiHeadAttention(config) + self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__(self, x, mask): + y = self.ln1(x) + y, _ = self.attention(y, y, y, mask=mask) + x = x + y + + y = self.ln2(x) + y = self.dense(y) + return x + y + + +class TransformerEncoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerEncoderLayer(config) for i in range(config.num_layers) + ] + self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) + + def __call__(self, x: mx.array): + pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) + for layer in self.layers: + x = layer(x, mask=pos_bias) + return self.ln(x) + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.self_attention = MultiHeadAttention(config) + self.cross_attention = MultiHeadAttention(config) + self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__( + self, + x: mx.array, + memory: mx.array, + mask: mx.array, + memory_mask: mx.array, + cache: Optional[List[Tuple[mx.array, mx.array]]] = None, + ): + y = self.ln1(x) + y, cache = self.self_attention(y, y, y, mask, cache) + x = x + y + + y = self.ln2(x) + y, _ = self.cross_attention(y, memory, memory, memory_mask) + x = x + y + + y = self.ln3(x) + y = self.dense(y) + x = x + y + + return x, cache + + +class TransformerDecoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerDecoderLayer(config) for i in range(config.num_layers) + ] + self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) + + def __call__(self, x, memory, mask, memory_mask, cache=None): + if cache is not None: + offset = cache[0][0].shape[3] + else: + offset = 0 + cache = [None] * len(self.layers) + + T = offset + x.shape[1] + pos_bias = self.relative_attention_bias(T, T, offset=offset) + if mask is not None: + mask += pos_bias + else: + mask = pos_bias + + for e, layer in enumerate(self.layers): + x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e]) + x = self.ln(x) + + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: T5Config): + self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) + + def __call__(self, inputs): + return self.linear(inputs) + + +class T5(nn.Module): + def __init__(self, config: T5Config): + self.wte = nn.Embedding(config.vocab_size, config.d_model) + self.encoder = TransformerEncoder(config) + self.decoder = TransformerDecoder(config) + self.tie_word_embeddings = config.tie_word_embeddings + if not self.tie_word_embeddings: + self.lm_head = OutputHead(config) + self.model_dim = config.d_model + + def encode(self, inputs: mx.array): + return self.encoder(self.wte(inputs)) + + def decode( + self, + inputs: mx.array, + memory: mx.array, + cache=None, + ): + inputs = self.wte(inputs) + T = inputs.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(inputs.dtype) + else: + mask = None + + y, cache = self.decoder( + inputs, memory=memory, mask=mask, memory_mask=None, cache=cache + ) + if not self.tie_word_embeddings: + y *= self.model_dim**-0.5 + y = self.lm_head(y) + else: + y = y @ self.wte.weight.T + return y, cache + + def __call__( + self, + inputs: mx.array, + decoder_inputs: mx.array, + ): + return self.decode(decoder_inputs, self.encode(inputs))[0] + + +class Tokenizer: + def __init__(self, model_name: str, config: T5Config): + self._decoder_start_id = config.decoder_start_token_id + self._tokenizer = T5Tokenizer.from_pretrained( + args.model, + legacy=False, + model_max_length=getattr(config, 'n_positions', 512) + ) + + @property + def eos_id(self) -> int: + return self._tokenizer.eos_token_id + + @property + def decoder_start_id(self) -> int: + return self._decoder_start_id + + def encode(self, s: str) -> mx.array: + return mx.array( + self._tokenizer( + s, + return_tensors="np", + return_attention_mask=False, + )["input_ids"] + ) + + def decode(self, t: List[int], with_sep: bool = True) -> str: + tokens = self._tokenizer.convert_ids_to_tokens(t) + return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) + + +def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + prompt = tokenizer.encode(prompt) + decoder_inputs = mx.array([tokenizer.decoder_start_id]) + memory = model.encode(prompt) + cache = None + y = decoder_inputs + while True: + logits, cache = model.decode(y[None], memory, cache=cache) + y = sample(logits[:, -1, :]) + yield y.squeeze() + + +def load_model(model_name: str, dtype: str = "float16"): + config = T5Config.from_pretrained(args.model) + dtype = getattr(mx, dtype) + model = T5(config) + file_name = model_name.replace("/", "-") + weights = mx.load(f"{file_name}.npz") + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + model.update(weights) + mx.eval(model.parameters()) + return model, Tokenizer(args.model, config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="T5 Inference script") + parser.add_argument( + "--model", + type=str, + help="Name of the T5 model.", + default="t5-small", + ) + parser.add_argument( + "--prompt", + help="", + default="translate English to German: That is good.", + ) + parser.add_argument( + "--encode-only", + action="store_true", + default=False, + help="Whether to decode or not. If true, will output last layer of encoder.", + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument( + "--dtype", + help="The model data type.", + type=str, + choices=["float16", "bfloat16", "float32"], + default="bfloat16", + ) + + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model(args.model, args.dtype) + + if args.encode_only: + print("[INFO] Encoding with T5...", flush=True) + print(args.prompt, flush=True) + encoder_output = model.encode(tokenizer.encode(args.prompt)) + print(encoder_output, flush=True) + exit(0) + + print("[INFO] Generating with T5...", flush=True) + print("Input: ", args.prompt, flush=True) + + start = perf_counter_ns() + for token, n_tokens in zip( + generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens) + ): + if token.item() == tokenizer.eos_id: + break + print( + tokenizer.decode([token.item()], with_sep=n_tokens > 0), + end="", + flush=True, + ) + + n_tokens += 1 + end = perf_counter_ns() + elapsed = (end - start) / 1.0e9 + print() + print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")