add speculative decoding example for llama (#149)

* speculative decoding

* add sample 0

* spec decode gives same results as regular decode

* rebase

* use accept reject criteria

* switch to t5

* update readme

* readme nit

* nits

* nits

* nits

---------

Co-authored-by: Benjamin Anderson <benjamin@Benjamins-MBP.lan>
Co-authored-by: Awni Hannun <awni@apple.com>
Benjamin Anderson 2023-12-28 17:20:43 -06:00 committed by GitHub
parent 07c163d9d9
commit 09566c7257
7 changed files with 775 additions and 1 deletion

@@ -0,0 +1,66 @@
# Speculative Decoding
This example implements speculative decoding with the T5 model for text
generation.[^1][^2] Speculative decoding uses a smaller draft model to propose
several tokens and a larger main model to decide which of them to accept. The
distribution of the generated text is identical to what the larger model would
produce on its own, but with far fewer forward passes of the large model, since
the large model can evaluate all of the draft tokens in parallel.
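To make the algorithm concrete, below is a minimal, self-contained sketch of a
single speculative step using toy stand-in models (NumPy only, greedy
proposals). The function and variable names here are purely illustrative; the
actual T5 implementation is in `decoder.py`.
```python
import numpy as np

VOCAB, NUM_DRAFT = 8, 4
rng = np.random.default_rng(0)

def toy_logprobs(context, temperature):
    # Stand-in for a language model: prefers the token (last + 1) mod VOCAB.
    logits = -np.abs(np.arange(VOCAB) - (context[-1] + 1) % VOCAB) / temperature
    return logits - np.logaddexp.reduce(logits)

def speculative_step(context, delta=0.0):
    # 1) Propose NUM_DRAFT tokens greedily with the cheap draft model.
    draft_tokens, draft_lp, seq = [], [], list(context)
    for _ in range(NUM_DRAFT):
        lp = toy_logprobs(seq, temperature=2.0)  # flatter "draft" distribution
        t = int(lp.argmax())
        draft_tokens.append(t)
        draft_lp.append(lp[t])
        seq.append(t)
    # 2) Score every proposed position with the large model. A transformer can
    #    do this for all of the draft tokens in one batched forward pass.
    model_lp = [
        toy_logprobs(list(context) + draft_tokens[:i], temperature=1.0)
        for i in range(NUM_DRAFT + 1)
    ]
    # 3) Keep the longest prefix that passes the accept/reject test, then append
    #    the large model's own prediction for the next position.
    accepted = []
    for i, t in enumerate(draft_tokens):
        u = rng.uniform()
        if np.log(max(u - delta, 1e-12)) > model_lp[i][t] - draft_lp[i]:
            break  # reject this draft token and everything after it
        accepted.append(t)
    accepted.append(int(model_lp[len(accepted)].argmax()))
    return accepted

print(speculative_step([3]))  # -> [4, 5, 6, 7, 0]: all draft tokens accepted here
```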
### Setup
First, install the requirements:
```
cd speculative_decoding
pip install -r requirements.txt
```
Then convert the model and the draft model. We'll use T5-XXL (11B parameters)
for the main model. Convert it with:
```
python convert.py --model t5-11b
```
We'll use T5-small for the draft model. Convert it with:
```
python convert.py --model t5-small
```
### Run
You can run with the default arguments:
```
python main.py
```
To see a full list of options use:
```
python main.py --help
```
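For example, to compare speculative decoding against regular decoding with the
converted T5-XXL main model and T5-small draft model:
```
python main.py --model-name t5-11b --draft-model-name t5-small --num-draft 5
python main.py --model-name t5-11b --draft-model-name t5-small --regular-decode
```
The speculative run also reports how many draft tokens were accepted and how
many main-model decoding steps were needed.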
### Notes
Speculative decoding works well when most of the tokens from the draft model
are accepted by the larger model. That's more likely to happen if the models
are trained on similar data.
One way to increase the chance of accepting a draft token is with the parameter
`--delta`, which can be in the range $[0, 1]$. If it is $1$, all of the draft
tokens are accepted by the main model. If it is $0$, only draft tokens that
satisfy the original acceptance criterion are kept.[^1] Values closer to $1$
increase the chance that a draft token is accepted.
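Concretely, the acceptance test implemented in `decoder.py` draws
$u \sim \mathrm{Uniform}(0, 1)$ for each draft token $t$ and accepts it when
$$\log \max(u - \delta, 0) \le \log p_{\text{main}}(t) - \log p_{\text{draft}}(t),$$
so $\delta = 0$ recovers the standard criterion $u \le p_{\text{main}}(t) / p_{\text{draft}}(t)$
and $\delta = 1$ accepts every draft token. All draft tokens after the first
rejected one are discarded.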
Conversely, the fewer draft tokens the main model accepts, the more expensive
speculative decoding becomes. You can use `--num-draft` to tune the number of
draft tokens proposed per main-model evaluation: decreasing it reduces the
number of discarded draft tokens at the expense of more large-model
evaluations.
[^1]: See the paper [Fast Inference from Transformers via Speculative
Decoding](https://arxiv.org/abs/2211.17192)
[^2]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).

@@ -0,0 +1,75 @@
import numpy as np
from transformers import T5ForConditionalGeneration
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
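# Map Hugging Face T5 parameter names onto the names used by the MLX model in
# model.py. For example, "encoder.block.0.layer.0.SelfAttention.q.weight"
# becomes "encoder.layers.0.attention.query_proj.weight".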
def replace_key(key: str) -> str:
for old, new in SHARED_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in ENCODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in DECODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
return key
def convert(model_name, dtype):
dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
}
file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz")
np.savez(f"{file_name}.npz", **weights)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
parser.add_argument(
"--model",
type=str,
help="Name of the T5 model.",
default="t5-small",
)
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "float32"],
default="float32",
)
args = parser.parse_args()
convert(args.model, args.dtype)

@@ -0,0 +1,191 @@
from typing import List
import mlx.core as mx
import transformers
from model import Model
class Tokenizer:
def __init__(self, model_name: str):
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
legacy=False,
model_max_length=512,
)
self._decoder_start_id = 0
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
"input_ids"
].squeeze(0)
)
def decode(self, t: List[int]) -> str:
return self._tokenizer.decode(t)
class SpeculativeDecoder:
def __init__(
self,
model: Model,
draft_model: Model,
tokenizer: str,
num_draft: int = 5,
delta: float = 0.0,
):
self.tokenizer = Tokenizer(tokenizer)
self.model = model
self.draft_model = draft_model
self.num_draft = num_draft
self.delta = delta
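    # Greedy incremental generation with either the main model or the draft
    # model. Yields (token, log-probability of that token) one step at a time;
    # the model's key/value cache carries the decoding state between steps.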
def _generate(
self,
x: mx.array,
memory: mx.array,
draft: bool = False,
):
model = self.draft_model if draft else self.model
while True:
logits = model.decode(x[None], memory)[0, -1]
x = mx.argmax(logits, keepdims=True)
lognorm = mx.logsumexp(logits.astype(mx.float32))
logprob = logits[x] - lognorm
yield x, logprob
def generate(
self,
prompt,
max_tokens: int = 100,
):
memory = self.model.encode(self.tokenizer.encode(prompt)[None])
x = mx.array([self.tokenizer.decoder_start_id])
skip = 0
outputs = []
for (token, _), n in zip(self._generate(x, memory), range(max_tokens)):
if token == self.tokenizer.eos_id:
break
outputs.append(token.item())
if (n + 1) % 10 == 0:
str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)
skip = len(str_output)
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
print()
self.model.reset_cache()
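    # Decide how many leading draft tokens to keep. A draft token is accepted
    # when log(max(u - delta, 0)) <= log p_main(token) - log p_draft(token),
    # with u ~ Uniform(0, 1); everything after the first rejection is dropped.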
def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
# accept_toks = mx.argmax(model_logits, axis=-1) == draft_tokens
model_probs = mx.take_along_axis(
model_logits,
draft_tokens[:, None],
axis=-1,
).squeeze(-1)
model_probs -= mx.logsumexp(model_logits.astype(mx.float32), axis=-1)
unis = mx.random.uniform(shape=(draft_tokens.size,))
log_unis = mx.log(mx.maximum(unis - self.delta, 0.0))
        accept_toks = log_unis <= (model_probs - draft_probs)
num_to_accept = (accept_toks.tolist() + [False]).index(False)
return num_to_accept
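    # Speculative decoding loop: at each step, propose up to `num_draft` tokens
    # with the draft model, verify all of them with a single forward pass of
    # the main model, keep the accepted prefix plus one token from the main
    # model, and rewind both key/value caches past any rejected draft tokens.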
def speculative_decode(
self,
prompt,
max_tokens: int = 100,
):
def sample(logits):
return mx.argmax(logits, axis=-1)
prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None]
memory = self.model.encode(prompt)
draft_memory = self.draft_model.encode(prompt)
tokens = mx.array([self.tokenizer.decoder_start_id])
n_steps = 0
ntoks = 0
n_accepted = 0
n_draft = 0
outputs = []
skip = 0
draft_inputs = tokens
inputs = tokens
while True:
# For each decoding step: generate n tokens from a draft model
draft_tokens = []
draft_probs = []
for _, (t, p) in zip(
range(ntoks, min(ntoks + self.num_draft, max_tokens)),
self._generate(draft_inputs, draft_memory, draft=True),
):
draft_tokens.append(t)
draft_probs.append(p)
if t.item() == self.tokenizer.eos_id:
break
# Verify the draft tokens with the last verified token:
draft_tokens = mx.concatenate(draft_tokens)
draft_probs = mx.concatenate(draft_probs)
verify_tokens = mx.concatenate([inputs, draft_tokens])
logits = self.model.decode(
verify_tokens[None, :],
memory,
).squeeze(0)
# Only keep samples that match the draft:
num_to_accept = self._get_num_accept(
draft_tokens,
draft_probs,
logits[:-1],
)
new_tokens = draft_tokens[:num_to_accept]
# Get the next token from the main model as well
new_tokens = mx.concatenate(
[new_tokens, mx.argmax(logits[num_to_accept], keepdims=True)]
)
n_accepted += num_to_accept
n_draft += draft_tokens.size
# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
self.draft_model.truncate_cache(n - new_tokens.size)
self.model.truncate_cache(n - new_tokens.size + 1)
n_steps += 1
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or ntoks >= max_tokens:
break
outputs.append(t)
ntoks += 1
str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)
skip = len(str_output)
if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
inputs = draft_inputs[-1:]
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
print()
self.model.reset_cache()
self.draft_model.reset_cache()
return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps}

@@ -0,0 +1,99 @@
import argparse
import time
import mlx.core as mx
from decoder import SpeculativeDecoder
from mlx.utils import tree_unflatten
from model import Model
from transformers import T5Config
def load_model(model_name: str):
config = T5Config.from_pretrained(model_name)
model = Model(config)
weights = mx.load(f"{model_name}.npz")
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model
def main(args):
mx.random.seed(args.seed)
spec_decoder = SpeculativeDecoder(
model=load_model(args.model_name),
draft_model=load_model(args.draft_model_name),
tokenizer=args.model_name,
delta=args.delta,
num_draft=args.num_draft,
)
tic = time.time()
print(args.prompt)
if args.regular_decode:
spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
else:
stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
print("=" * 10)
print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
print(f"Decoding steps {stats['n_steps']}.")
toc = time.time()
print("=" * 10)
print(f"Full generation time {toc - tic:.3f}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Speculative decoding with T5")
parser.add_argument(
"--num-draft",
type=int,
default=5,
help="Number of draft tokens to use per decoding step.",
)
parser.add_argument(
"--model-name",
help="Name of the model.",
default="t5-small",
)
parser.add_argument(
"--draft-model-name",
help="Name of the draft model.",
default="t5-small",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="PRNG seed.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate.",
)
parser.add_argument(
"--prompt",
default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.",
help="The prompt processed by the model.",
)
parser.add_argument(
"--delta",
type=float,
default=0.1,
help="Lenience for accepting the proposal tokens.",
)
parser.add_argument(
"--regular-decode",
action="store_true",
help="Use regular decoding instead of speculative decoding.",
)
args = parser.parse_args()
main(args)

@@ -0,0 +1,341 @@
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map
from transformers import T5Config
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from HF Tensorflow:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position)
else:
relative_position = -mx.minimum(
relative_position, mx.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
).astype(mx.int16)
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += mx.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
)
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
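        # If a cache is given, append the new keys and values to the cached ones.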
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim]
scores = queries @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
t = x.dtype
output = self._norm(x).astype(t)
return self.weight * output
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(
self,
x: mx.array,
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
y, _ = self.cross_attention(y, memory, memory, memory_mask)
x = x + y
y = self.ln3(x)
y = self.dense(y)
x = x + y
return x, cache
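# Additive causal mask: each query position can attend only to itself and to
# earlier positions; `offset` accounts for tokens already stored in the cache.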
def create_additive_causal_mask(N: int, offset: int = 0):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
return mask * -1e9
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
    def __call__(self, x, memory, cache=None):
        if cache is None:
            cache = [None] * len(self.layers)
        if cache[0] is not None:
            offset = cache[0][0].shape[2]
        else:
            offset = 0
T = x.shape[1]
if T > 1:
mask = create_additive_causal_mask(T, offset)
else:
mask = None
pos_bias = self.relative_attention_bias(T + offset, T + offset, offset=offset)
if mask is not None:
mask += pos_bias
else:
mask = pos_bias
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, memory, mask, None, cache=cache[e])
x = self.ln(x)
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
return self.linear(inputs)
class Model(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.tie_word_embeddings = config.tie_word_embeddings
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
self.reset_cache()
def encode(self, inputs: mx.array):
return self.encoder(self.wte(inputs))
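    # Drop the last `num_to_truncate` positions from every layer's key/value
    # cache. The speculative decoder uses this to rewind past rejected draft
    # tokens so that the cache only contains verified tokens.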
def truncate_cache(self, num_to_truncate):
if num_to_truncate <= 0:
return
cache_length = self.cache[0][0].shape[2]
if num_to_truncate < cache_length:
self.cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], self.cache)
else:
self.reset_cache()
def reset_cache(self):
self.cache = [None] * len(self.decoder.layers)
def decode(
self,
inputs: mx.array,
memory: mx.array,
):
inputs = self.wte(inputs)
y, self.cache = self.decoder(inputs, memory=memory, cache=self.cache)
        if self.tie_word_embeddings:
            # Rescale the decoder output before projecting with the tied input
            # embeddings, matching the Hugging Face T5 implementation.
            y *= self.model_dim**-0.5
            y = y @ self.wte.weight.T
        else:
            y = self.lm_head(y)
return y
def __call__(
self,
inputs: mx.array,
decoder_inputs: mx.array,
):
return self.decode(decoder_inputs, self.encode(inputs))[0]

@@ -0,0 +1,3 @@
mlx>=0.0.6
transformers
numpy

@@ -125,7 +125,6 @@ class MultiHeadAttention(nn.Module):
             values = mx.concatenate([value_cache, values], axis=2)
         # Dimensions are [batch x num heads x sequence x hidden dim]
-        queries = queries
         scores = queries @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)