From 34a62ddc4949bfaacb52d3e2015d362d95ad3e94 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 28 Dec 2023 14:55:25 -0800 Subject: [PATCH] switch to t5 --- llms/speculative_decoding/README.md | 49 ++- llms/speculative_decoding/convert.py | 75 ++++ llms/speculative_decoding/decoder.py | 125 +++--- llms/speculative_decoding/main.py | 73 +++- llms/speculative_decoding/model.py | 437 ++++++++++++++------- llms/speculative_decoding/requirements.txt | 5 +- t5/t5.py | 1 - 7 files changed, 529 insertions(+), 236 deletions(-) create mode 100644 llms/speculative_decoding/convert.py diff --git a/llms/speculative_decoding/README.md b/llms/speculative_decoding/README.md index 0a549ac3..53b773f7 100644 --- a/llms/speculative_decoding/README.md +++ b/llms/speculative_decoding/README.md @@ -1,11 +1,11 @@ # Speculative Decoding -This example implements [speculative decoding] for text generation.[^1]. -Speculative decoding uses a smaller draft model to propose several tokens, and -then a larger model which decides which tokens to accept. The generated text is -identical to what the larger model would produce on its own, but with far fewer -forward passes of the large model since it can evaluate the draft tokens in -parallel. +This example implements speculative decoding with the T5 model for text +generation.[^1] Speculative decoding uses a smaller draft model to propose +several tokens, and a larger model to decide which tokens to accept. The +distribution of the generated text is identical to what the larger model would +produce on its own, but with far fewer forward passes of the large model since +it can evaluate the draft tokens in parallel. ### Setup @@ -16,6 +16,19 @@ cd speculative_decoding pip install -r requirements.txt ``` +Then convert the model and the draft model. For example, you can convert th +T5 11B model with: + +``` +python convert.py --model t5-11b +``` + +And for the draft model, convert the T5 small model with: + +``` +python convert.py --model t5-small +``` + ### Run You can run with the default arguments: @@ -24,9 +37,27 @@ You can run with the default arguments: python main.py ``` +To see a full list of options use: +``` +python main.py --help +``` + +### Notes + Speculative decoding works well when most of the tokens from the draft model are accepted by the larger model. That's more likely to happen if the models -are trained on similar data. The default setting in this example uses TinyLlama -as a draft morel for Llama 7B. +are trained on similar data. -[^1] See the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192) +One way to increase the chance of accepting a draft token is with the parameter +`--delta`. This parameter can be in the range `[0, 1]`. If it is `1` then all +the draft tokens will be accepted by the model. If it is `0`, then only draft +tokens which match the original acceptance criterion kept.[^1] Values closer to +`1` increase the chance that a draft token is accepted. + +Conversely, the fewer draft tokens accepted by the model, the more expensive +speculative decoding is. You can use `--draft` to tune the number of draft +tokens per model evaluation in order to reduce the number of discarded draft +tokens. + +[^1] See the paper [Fast Inference from Transformers via Speculative +Decoding](https://arxiv.org/abs/2211.17192) diff --git a/llms/speculative_decoding/convert.py b/llms/speculative_decoding/convert.py new file mode 100644 index 00000000..e2108a0c --- /dev/null +++ b/llms/speculative_decoding/convert.py @@ -0,0 +1,75 @@ +import numpy as np +from transformers import T5ForConditionalGeneration + +SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + +DECODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".self_attention."), + (".layer.1.EncDecAttention.", ".cross_attention."), + (".layer.2.DenseReluDense.", ".dense."), +] + + +def replace_key(key: str) -> str: + for old, new in SHARED_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + if key.startswith("encoder."): + for old, new in ENCODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + elif key.startswith("decoder."): + for old, new in DECODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + return key + + +def convert(model_name, dtype): + dtype = getattr(np, dtype) + model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") + weights = { + replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items() + } + file_name = model_name.replace("/", "-") + print(f"Saving weights to {file_name}.npz") + np.savez(f"{file_name}.npz", **weights) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Convert T5 weights to MLX") + parser.add_argument( + "--model", + type=str, + help="Name of the T5 model.", + default="t5-small", + ) + parser.add_argument( + "--dtype", + help="The model data type.", + type=str, + choices=["float16", "float32"], + default="float32", + ) + args = parser.parse_args() + convert(args.model, args.dtype) diff --git a/llms/speculative_decoding/decoder.py b/llms/speculative_decoding/decoder.py index fc29118b..838edd91 100644 --- a/llms/speculative_decoding/decoder.py +++ b/llms/speculative_decoding/decoder.py @@ -1,4 +1,3 @@ -import time from dataclasses import dataclass, field from typing import List, Optional @@ -6,18 +5,26 @@ import mlx.core as mx import mlx.nn as nn import numpy as np import transformers -from model import Llama -from prompts import create_urial_prompt +from model import Model class Tokenizer: def __init__(self, model_name: str): - self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, + legacy=False, + model_max_length=512, + ) + self._decoder_start_id = 0 @property def eos_id(self) -> int: return self._tokenizer.eos_token_id + @property + def decoder_start_id(self) -> int: + return self._decoder_start_id + def encode(self, s: str) -> mx.array: return mx.array( self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[ @@ -25,44 +32,34 @@ class Tokenizer: ].squeeze(0) ) - def decode(self, t: List[int], with_sep: bool = True) -> str: - tokens = self._tokenizer.convert_ids_to_tokens(t) - return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) + def decode(self, t: List[int]) -> str: + return self._tokenizer.decode(t) class SpeculativeDecoder: def __init__( self, - model: str, - draft_model: str = None, + model: Model, + draft_model: Model, + tokenizer: str, num_draft: int = 5, delta: float = 0.0, ): - self.tokenizer = Tokenizer(model) - self.model = Llama.from_hugging_face(model) - if draft_model is not None: - self.draft_model = Llama.from_hugging_face(draft_model) + self.tokenizer = Tokenizer(tokenizer) + self.model = model + self.draft_model = draft_model self.num_draft = num_draft self.delta = delta - def tokenize(self, prompt): - # if self.tokenizer.chat_template is not None: - # tokenized = self.tokenizer.apply_chat_template( - # prompt, tokenize=True, add_generation_prompt=True - # ) - # else: - # use urial zero-shot template - tokenized = self.tokenizer.encode(create_urial_prompt(prompt["content"])) - return tokenized - def _generate( self, x: mx.array, + memory: mx.array, draft: bool = False, ): model = self.draft_model if draft else self.model while True: - logits = model(x[None, :], next_tokens=1).squeeze((0, 1)) + logits = model.decode(x[None], memory)[0, -1] x = mx.argmax(logits, keepdims=True) lognorm = mx.logsumexp(logits.astype(mx.float32)) logprob = logits[x] - lognorm @@ -72,25 +69,26 @@ class SpeculativeDecoder: self, prompt, max_tokens: int = 100, - draft: bool = False, ): - x = self.tokenize(prompt) - start = time.time() - for (token, _), n in zip(self._generate(x, draft=draft), range(max_tokens)): - token = token.item() + memory = self.model.encode(self.tokenizer.encode(prompt)[None]) + x = mx.array([self.tokenizer.decoder_start_id]) + skip = 0 + outputs = [] + for (token, _), n in zip(self._generate(x, memory), range(max_tokens)): if token == self.tokenizer.eos_id: break - print(self.tokenizer.decode(token, with_sep=n > 0), end="", flush=True) - run_time = time.time() - start + outputs.append(token.item()) + if (n + 1) % 10 == 0: + str_output = self.tokenizer.decode(outputs) + print(str_output[skip:], end="", flush=True) + skip = len(str_output) + + print(self.tokenizer.decode(outputs)[skip:], end="", flush=True) print() - print(f"=== GENERATED {n + 1} TOKENS in {run_time} SECONDS ===") - if draft: - self.draft_model.reset_cache() - else: - self.model.reset_cache() + self.model.reset_cache() def _get_num_accept(self, draft_tokens, draft_probs, model_logits): - # equal_toks = sampled[:-1] == draft_tokens + # accept_toks = mx.argmax(model_logits, axis=-1) == draft_tokens model_probs = mx.take_along_axis( model_logits, draft_tokens[:, None], @@ -111,14 +109,19 @@ class SpeculativeDecoder: def sample(logits): return mx.argmax(logits, axis=-1) - tokens = mx.array(self.tokenize(prompt), mx.uint32) - start = time.time() + prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None] + memory = self.model.encode(prompt) + draft_memory = self.draft_model.encode(prompt) - decoding_steps = 0 + tokens = mx.array([self.tokenizer.decoder_start_id]) + + n_steps = 0 ntoks = 0 - accepted_draft_tokens = 0 - total_draft_tokens = 0 + n_accepted = 0 + n_draft = 0 + outputs = [] + skip = 0 draft_inputs = tokens inputs = tokens while True: @@ -127,7 +130,7 @@ class SpeculativeDecoder: draft_probs = [] for _, (t, p) in zip( range(ntoks, min(ntoks + self.num_draft, max_tokens)), - self._generate(draft_inputs, draft=True), + self._generate(draft_inputs, draft_memory, draft=True), ): draft_tokens.append(t) draft_probs.append(p) @@ -138,10 +141,10 @@ class SpeculativeDecoder: draft_tokens = mx.concatenate(draft_tokens) draft_probs = mx.concatenate(draft_probs) verify_tokens = mx.concatenate([inputs, draft_tokens]) - logits = self.model( - verify_tokens[None, :], next_tokens=draft_tokens.size + 1 + logits = self.model.decode( + verify_tokens[None, :], + memory, ).squeeze(0) - # sampled = sample(logits).squeeze(0) # Only keep samples that match the draft: num_to_accept = self._get_num_accept( @@ -155,38 +158,34 @@ class SpeculativeDecoder: [new_tokens, mx.argmax(logits[num_to_accept], keepdims=True)] ) - accepted_draft_tokens += num_to_accept - total_draft_tokens += draft_tokens.size + n_accepted += num_to_accept + n_draft += draft_tokens.size # Rewind the cache for unaccepted tokens: if (n := draft_tokens.size) > num_to_accept: self.draft_model.truncate_cache(n - new_tokens.size) self.model.truncate_cache(n - new_tokens.size + 1) - decoding_steps += 1 + n_steps += 1 for t in new_tokens.tolist(): if t == self.tokenizer.eos_id or ntoks >= max_tokens: break - print(self.tokenizer.decode(t, with_sep=ntoks > 0), end="", flush=True) + outputs.append(t) ntoks += 1 + + str_output = self.tokenizer.decode(outputs) + print(str_output[skip:], end="", flush=True) + skip = len(str_output) + if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id: break draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :] inputs = draft_inputs[-1:] - end = time.time() + print(self.tokenizer.decode(outputs)[skip:], end="", flush=True) + print() + self.model.reset_cache() self.draft_model.reset_cache() - print() - print( - "=== GENERATED", - ntoks, - "TOKENS IN", - round(end - start, 2), - "SECONDS ===", - ) - print( - f"=== ACCEPTED {accepted_draft_tokens} of {total_draft_tokens} DRAFT TOKENS ===" - ) - print("=== DECODING STEPS", decoding_steps, "===") + return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps} diff --git a/llms/speculative_decoding/main.py b/llms/speculative_decoding/main.py index eeae15ca..259f1507 100644 --- a/llms/speculative_decoding/main.py +++ b/llms/speculative_decoding/main.py @@ -1,31 +1,51 @@ import argparse +import glob +import json +import time +from pathlib import Path import mlx.core as mx +import mlx.nn as nn from decoder import SpeculativeDecoder +from mlx.utils import tree_unflatten +from model import Model +from transformers import T5Config + + +def load_model(model_name: str): + config = T5Config.from_pretrained(model_name) + model = Model(config) + weights = mx.load(f"{model_name}.npz") + weights = tree_unflatten(list(weights.items())) + model.update(weights) + mx.eval(model.parameters()) + return model def main(args): mx.random.seed(args.seed) spec_decoder = SpeculativeDecoder( - # model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T", - model="meta-llama/Llama-2-7b-hf", - draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T", + model=load_model(args.model_name), + draft_model=load_model(args.draft_model_name), + tokenizer=args.model_name, delta=args.delta, num_draft=args.num_draft, ) - prompt = {"role": "user", "content": "Finish the monologue: To be, or not to be..."} + tic = time.time() + print(args.prompt) + if args.regular_decode: + spec_decoder.generate(args.prompt, max_tokens=args.max_tokens) + else: + stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens) + print("=" * 10) + print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.") + print(f"Decoding steps {stats['n_steps']}.") - # Do 1 regular generation to get warmed up (the first one is slow) - # engine.generate(messages, max_tokens=1) - # engine.generate(messages, max_tokens=1, draft=True) - - # Time regular generation - spec_decoder.generate(prompt, max_tokens=125) - - # Time speculative decoding - spec_decoder.speculative_decode(prompt, max_tokens=125) + toc = time.time() + print("=" * 10) + print(f"Full generation time {toc - tic:.3f}") if __name__ == "__main__": @@ -36,17 +56,44 @@ if __name__ == "__main__": default=5, help="Number of draft tokens to use per decoding step.", ) + parser.add_argument( + "--model-name", + help="Name of the model.", + default="t5-small", + ) + parser.add_argument( + "--draft-model-name", + help="Name of the draft model.", + default="t5-small", + ) parser.add_argument( "--seed", type=int, default=0, help="PRNG seed.", ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate.", + ) + parser.add_argument( + "--prompt", + default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.", + help="The prompt processed by the model.", + ) parser.add_argument( "--delta", type=float, default=0.1, help="Lenience for accepting the proposal tokens.", ) + parser.add_argument( + "--regular-decode", + action="store_true", + help="Use regular decoding instead of speculative decoding.", + ) args = parser.parse_args() main(args) diff --git a/llms/speculative_decoding/model.py b/llms/speculative_decoding/model.py index 1bf70fee..ed4a7d77 100644 --- a/llms/speculative_decoding/model.py +++ b/llms/speculative_decoding/model.py @@ -1,17 +1,135 @@ -from typing import Optional, Tuple +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn +import numpy as np from mlx.utils import tree_map, tree_unflatten -from transformers import AutoModelForCausalLM, LlamaConfig +from transformers import AutoTokenizer, T5Config -def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.float32): - rinds = mx.arange(offset + N) - linds = mx.arange(offset, offset + N) if offset else rinds - mask = linds[:, None] < rinds[None] - mask = mask.astype(dtype) * -1e9 - return mask +def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 +): + """ + Adapted from HF Tensorflow: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets + relative_position = mx.abs(relative_position) + else: + relative_position = -mx.minimum( + relative_position, mx.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) + relative_position_if_large = max_exact + ( + mx.log(relative_position.astype(mx.float32) / max_exact) * scale + ).astype(mx.int16) + relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += mx.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding( + config.relative_attention_num_buckets, config.num_heads + ) + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = _relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + inner_dim = config.d_kv * config.num_heads + self.num_heads = config.num_heads + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array], + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> [mx.array, Tuple[mx.array, mx.array]]: + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, _ = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scores = queries @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask.astype(scores.dtype) + + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.out_proj(values_hat), (keys, values) class RMSNorm(nn.Module): @@ -24,177 +142,200 @@ class RMSNorm(nn.Module): return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) + t = x.dtype + output = self._norm(x).astype(t) return self.weight * output -class Attention(nn.Module): - def __init__(self, config: LlamaConfig): +class DenseActivation(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.config = config - - self.n_heads: int = config.num_attention_heads - self.n_kv_heads: int = config.num_key_value_heads - self.repeats = self.n_heads // self.n_kv_heads - self.head_dim = config.hidden_size // self.n_heads - self.scale = self.head_dim**-0.5 - - self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.k_proj = nn.Linear( - config.hidden_size, config.hidden_size // self.repeats, bias=False - ) - self.v_proj = nn.Linear( - config.hidden_size, config.hidden_size // self.repeats, bias=False - ) - self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.rope = nn.RoPE(self.head_dim, traditional=False) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose( - 0, 2, 1, 3 - ) # B, n_kv_heads, L, head_dim - - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - kv_size = a.shape[-1] - return a.reshape([B, self.n_heads, -1, kv_size]) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) else: - queries = self.rope(queries) - keys = self.rope(keys) + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") - scores = (queries * self.scale) @ repeat(keys).transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ repeat(values)).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) -class FeedForward(nn.Module): - def __init__(self, config: LlamaConfig): +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.gate_proj = nn.Linear( - config.hidden_size, config.intermediate_size, bias=False - ) - self.down_proj = nn.Linear( - config.intermediate_size, config.hidden_size, bias=False - ) - self.up_proj = nn.Linear( - config.hidden_size, config.intermediate_size, bias=False - ) + self.attention = MultiHeadAttention(config) + self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + def __call__(self, x, mask): + y = self.ln1(x) + y, _ = self.attention(y, y, y, mask=mask) + x = x + y + + y = self.ln2(x) + y = self.dense(y) + return x + y -class TransformerBlock(nn.Module): - def __init__(self, config: LlamaConfig): +class TransformerEncoder(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.n_heads = config.num_attention_heads - self.dim = config.hidden_size - self.self_attn = Attention(config=config) - self.mlp = FeedForward(config=config) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.layers = [ + TransformerEncoderLayer(config) for i in range(config.num_layers) + ] + self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) + + def __call__(self, x: mx.array): + pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) + for layer in self.layers: + x = layer(x, mask=pos_bias) + return self.ln(x) + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.self_attention = MultiHeadAttention(config) + self.cross_attention = MultiHeadAttention(config) + self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) def __call__( self, x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, cache + memory: mx.array, + mask: mx.array, + memory_mask: mx.array, + cache: Optional[List[Tuple[mx.array, mx.array]]] = None, + ): + y = self.ln1(x) + y, cache = self.self_attention(y, y, y, mask, cache) + x = x + y + + y = self.ln2(x) + y, _ = self.cross_attention(y, memory, memory, memory_mask) + x = x + y + + y = self.ln3(x) + y = self.dense(y) + x = x + y + + return x, cache -class Llama(nn.Module): - def __init__(self, config: LlamaConfig): +def create_additive_causal_mask(N: int, offset: int = 0): + rinds = mx.arange(offset + N) + linds = mx.arange(offset, offset + N) if offset else rinds + mask = linds[:, None] < rinds[None] + return mask * -1e9 + + +class TransformerDecoder(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.config = config - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [ - TransformerBlock(config=config) for _ in range(config.num_hidden_layers) - ] - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + n_layers = getattr(config, "num_decoder_layers", config.num_layers) + self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)] + self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) + + def __call__(self, x, memory, cache=None): + if cache[0] is not None: + offset = cache[0][0].shape[2] + else: + offset = 0 + + T = x.shape[1] + if T > 1: + mask = create_additive_causal_mask(T, offset) + else: + mask = None + + pos_bias = self.relative_attention_bias(T + offset, T + offset, offset=offset) + if mask is not None: + mask += pos_bias + else: + mask = pos_bias + + for e, layer in enumerate(self.layers): + x, cache[e] = layer(x, memory, mask, None, cache=cache[e]) + x = self.ln(x) + + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: T5Config): + self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) + + def __call__(self, inputs): + return self.linear(inputs) + + +class Model(nn.Module): + def __init__(self, config: T5Config): + self.wte = nn.Embedding(config.vocab_size, config.d_model) + self.encoder = TransformerEncoder(config) + self.decoder = TransformerDecoder(config) + self.tie_word_embeddings = config.tie_word_embeddings + if not self.tie_word_embeddings: + self.lm_head = OutputHead(config) + self.model_dim = config.d_model self.reset_cache() + def encode(self, inputs: mx.array): + return self.encoder(self.wte(inputs)) + def truncate_cache(self, num_to_truncate): if num_to_truncate <= 0: return - cache_length = self.kv_cache[0][0].shape[2] + cache_length = self.cache[0][0].shape[2] if num_to_truncate < cache_length: - self.kv_cache = tree_map( - lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache - ) + self.cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], self.cache) else: self.reset_cache() def reset_cache(self): - self.kv_cache = [None] * len(self.layers) + self.cache = [None] * len(self.decoder.layers) + + def decode( + self, + inputs: mx.array, + memory: mx.array, + ): + inputs = self.wte(inputs) + y, self.cache = self.decoder(inputs, memory=memory, cache=self.cache) + if not self.tie_word_embeddings: + y *= self.model_dim**-0.5 + y = self.lm_head(y) + else: + y = y @ self.wte.weight.T + return y def __call__( self, - x: mx.array, - next_tokens: int = -1, + inputs: mx.array, + decoder_inputs: mx.array, ): - if self.kv_cache[0]: - offset = self.kv_cache[0][0].shape[-2] - else: - offset = 0 - - if x.shape[1] > 1: - mask = create_additive_causal_mask(x.shape[1], offset) - mask = mask.astype(self.embed_tokens.weight.dtype) - else: - mask = None - - x = self.embed_tokens(x) - for idx, layer in enumerate(self.layers): - x, self.kv_cache[idx] = layer(x, mask, cache=self.kv_cache[idx]) - - if next_tokens > 0: - x = x[:, -next_tokens:] - - x = self.norm(x) - return self.lm_head(x) - - @classmethod - def from_hugging_face(cls, model_path: str, quantized: bool = True): - config = LlamaConfig.from_pretrained(model_path) - torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict() - weights = { - k.replace("model.", ""): mx.array(v.numpy(), mx.float16) - for k, v in torch_weights.items() - } - model = cls(config) - model.update(tree_unflatten(list(weights.items()))) - # if quantization is not None: - # nn.QuantizedLinear.quantize_module(model, **quantization) - mx.eval(model.parameters()) - return model + return self.decode(decoder_inputs, self.encode(inputs))[0] diff --git a/llms/speculative_decoding/requirements.txt b/llms/speculative_decoding/requirements.txt index ef977a3d..78e7a889 100644 --- a/llms/speculative_decoding/requirements.txt +++ b/llms/speculative_decoding/requirements.txt @@ -1,3 +1,4 @@ -mlx>=0.0.5 +mlx>=0.0.6 transformers -numpy \ No newline at end of file +numpy +accelerate diff --git a/t5/t5.py b/t5/t5.py index 2acd39b4..3812393c 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -125,7 +125,6 @@ class MultiHeadAttention(nn.Module): values = mx.concatenate([value_cache, values], axis=2) # Dimensions are [batch x num heads x sequence x hidden dim] - queries = queries scores = queries @ keys if mask is not None: scores = scores + mask.astype(scores.dtype)