spec decode gives same results as regular decode

This commit is contained in:
Awni Hannun
2023-12-27 21:16:07 -08:00
parent 761e61480e
commit 19ecb00bce
6 changed files with 242 additions and 303 deletions

View File

@@ -1,16 +1,32 @@
## Speculative Decoding with MLX
# Speculative Decoding
This example implements [speculative decoding](https://arxiv.org/abs/2211.17192), which allows you to use a smaller draft model to predict several tokens, and then a larger verification model to check them all in parallel. The results are output that is identical to what the larger model would produce, but with far fewer forward passes (as long as the reference model is good enough at guessing).
This example implements speculative decoding for text generation.[^1]
Speculative decoding uses a smaller draft model to propose several tokens, and
then a larger model which decides which tokens to accept. The generated text is
identical to what the larger model would produce on its own, but with far fewer
forward passes of the large model since it can evaluate the draft tokens in
parallel.
### Setup
First, install the requirements:
Install the requirements and then you can try it out:
```
cd speculative_decoding
pip install -r requirements.txt
python test.py
```
In order for that to happen, it's generally good if the models are trained on similar data, with a similar chat template, etc. For example, you could use Meta's 7B Llama as a draft model for the 13B Llama. In my tests, I've mostly used TinyLlama as a draft model for Llama-7B. The chat versions of TinyLlama and Llama-7B-Chat are trained with different templates, but it works OK. Alternatively, you can use base models, and a prompt to make the model act like a chat model (e.g. [URIAL](https://arxiv.org/abs/2312.01552)).
### Run
I believe the implementation is *correct* (it produces near-identical output with regular generation vs. speculative decoding, and when speculative decoding is enabled, the draft model does correctly predict many tokens). However, it assumes a batch size of 1 at the moment (I'm not actually sure how to handle batching where some drafts might have more correct tokens than others). Also I feel like it could be faster!
You can run with the default arguments:
Before merging this in, I would appreciate some help understanding how to make this faster and optimizing the performance so it's actually useful!
```
python main.py
```
Speculative decoding works well when most of the tokens from the draft model
are accepted by the larger model. That's more likely to happen if the models
are trained on similar data. The default setting in this example uses TinyLlama
as a draft model for Llama 7B.
[^1]: See the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192)

View File

@@ -0,0 +1,160 @@
import transformers
from dataclasses import dataclass, field
from model import Llama
import mlx.core as mx
import mlx.nn as nn
import time
import numpy as np
from typing import List, Optional
from prompts import create_urial_prompt
class Tokenizer:
    """Thin wrapper around a Hugging Face tokenizer that works with MLX arrays."""

    # SentencePiece-style vocabularies prefix word-initial pieces with U+2581
    # instead of a literal space; `decode` maps it back to a separator.
    _WORD_BOUNDARY = "\u2581"

    def __init__(self, model_name: str):
        """Load the pretrained tokenizer identified by `model_name`."""
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    @property
    def eos_id(self) -> int:
        """The wrapped tokenizer's end-of-sequence token id."""
        return self._tokenizer.eos_token_id

    def encode(self, s: str) -> mx.array:
        """Tokenize `s` and return its input ids as an MLX array.

        The leading axis returned by the wrapped tokenizer is squeezed away.
        """
        encoded = self._tokenizer(
            s,
            return_tensors="np",
            return_attention_mask=False,
        )
        return mx.array(encoded["input_ids"].squeeze(0))

    def decode(self, t: List[int], with_sep: bool = True) -> str:
        """Convert token ids back to text.

        Args:
            t: A list of token ids. A single int is also accepted, since the
                streaming callers decode one token at a time.
            with_sep: When True every word-boundary marker becomes a space;
                when False the marker is dropped (used for the very first
                token so the output does not start with a space).

        Returns:
            The concatenated text of all token pieces.
        """
        if isinstance(t, int):
            t = [t]
        separator = " " if with_sep else ""
        pieces = self._tokenizer.convert_ids_to_tokens(t)
        return "".join(
            piece.replace(self._WORD_BOUNDARY, separator) for piece in pieces
        )
class SpeculativeDecoder:
    """Text generation with a target model, optionally sped up by a draft model.

    `generate` runs ordinary token-by-token decoding with either model.
    `speculative_decode` lets the draft model propose several tokens which the
    target model then checks in a single forward pass.
    """

    def __init__(self, model: str, draft_model: Optional[str] = None):
        """Load the tokenizer and target model, and the draft model if given.

        Args:
            model: Hugging Face path of the target model; also used to load
                the tokenizer.
            draft_model: Hugging Face path of the draft model, or None.
        """
        self.tokenizer = Tokenizer(model)
        self.model = Llama.from_hugging_face(model)
        # Always define the attribute so a missing draft model can be detected
        # with a clear error instead of an AttributeError.
        self.draft_model = None
        if draft_model is not None:
            self.draft_model = Llama.from_hugging_face(draft_model)

    @staticmethod
    def _sample(logits, temp: float):
        """Pick tokens from `logits`: argmax when temp == 0, else categorical."""
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    def _select_model(self, draft: bool):
        """Return the draft model when `draft` is true, else the target model.

        Raises:
            ValueError: If the draft model is requested but none was loaded.
        """
        if not draft:
            return self.model
        if self.draft_model is None:
            raise ValueError("No draft model was provided to SpeculativeDecoder.")
        return self.draft_model

    def tokenize(self, prompt):
        """Encode `prompt["content"]` wrapped in the URIAL zero-shot template."""
        return self.tokenizer.encode(create_urial_prompt(prompt["content"]))

    def _generate(
        self,
        x: mx.array,
        temp: float = 0.0,
        draft: bool = False,
    ):
        """Yield sampled tokens forever, feeding each one back into the model.

        The first model call receives all of `x`; every later call receives
        only the previously sampled token, so the model's own cache must carry
        the earlier context.
        """
        model = self._select_model(draft)
        while True:
            logit = model(x[None, :], next_token_only=True)
            x = self._sample(logit, temp)
            yield x

    def generate(
        self,
        prompt,
        max_tokens: int = 100,
        temp: float = 0.0,
        draft: bool = False,
    ):
        """Stream up to `max_tokens` tokens to stdout using a single model.

        Stops early at the end-of-sequence token, prints a timing summary and
        resets the cache of the model that was used.
        """
        x = self.tokenize(prompt)

        start = time.time()
        n = -1  # keeps the summary well defined when the loop never runs
        # `range` goes first so zip stops before asking the generator for a
        # token beyond `max_tokens` (which would cost an extra forward pass).
        for n, token in zip(
            range(max_tokens), self._generate(x, temp, draft=draft)
        ):
            token = token.item()
            if token == self.tokenizer.eos_id:
                break
            print(self.tokenizer.decode(token, with_sep=n > 0), end="", flush=True)

        run_time = time.time() - start
        print()
        print(f"=== GENERATED {n + 1} TOKENS in {run_time} SECONDS ===")
        self._select_model(draft).reset_cache()

    def speculative_decode(
        self, prompt, max_tokens: int = 100, temp: float = 0.0, n_draft: int = 5
    ):
        """Stream up to roughly `max_tokens` tokens using speculative decoding.

        Each step the draft model proposes up to `n_draft` tokens, the target
        model scores them in one pass, and the longest prefix on which both
        agree is kept (at least one target-model token is always kept).

        Raises:
            ValueError: If no draft model was loaded or `n_draft` < 1.
        """
        if self.draft_model is None:
            raise ValueError("speculative_decode requires a draft model.")
        if n_draft < 1:
            raise ValueError("n_draft must be at least 1.")

        eos_id = self.tokenizer.eos_id
        tokens = mx.array(self.tokenize(prompt), mx.uint32)
        start = time.time()

        decoding_steps = 0
        ntoks = 0
        accepted_draft_tokens = 0

        while ntoks < max_tokens:
            # For each decoding step: generate n tokens from a draft model.
            # `range` is the first zip argument on purpose: zip then stops
            # without pulling an extra token, so the draft cache is not
            # advanced past the last proposed token.
            draft_tokens = []
            for _, t in zip(
                range(ntoks, min(ntoks + n_draft, max_tokens)),
                self._generate(tokens, temp=temp, draft=True),
            ):
                draft_tokens.append(t)
                if t.item() == eos_id:
                    break

            # Verify the draft tokens with the last verified token. The final
            # draft token is not fed in: only the predictions *for* the draft
            # positions are needed.
            draft_tokens = mx.concatenate(draft_tokens)
            verify_tokens = mx.concatenate([tokens, draft_tokens])
            logits = self.model(verify_tokens[None, :-1])
            sampled = self._sample(
                logits[:, -draft_tokens.size :], temp
            ).squeeze(0)

            # Only keep samples that match the draft; the appended False
            # guarantees index() succeeds when every draft token matches.
            equal_toks = sampled == draft_tokens
            num_to_accept = (equal_toks.tolist() + [False]).index(False)
            new_tokens = sampled[: max(1, num_to_accept)]
            accepted_draft_tokens += num_to_accept

            # Rewind the cache for unaccepted tokens:
            if (n := draft_tokens.size) > num_to_accept:
                self.draft_model.truncate_cache(n - new_tokens.size)
                self.model.truncate_cache(n - new_tokens.size)

            decoding_steps += 1

            # Print the new tokens, stopping at the first end-of-sequence
            # token even if it is not the last token accepted this step.
            hit_eos = False
            for i, t in enumerate(new_tokens.tolist()):
                if t == eos_id:
                    hit_eos = True
                    break
                # Only the very first printed token drops its leading space.
                print(
                    self.tokenizer.decode(t, with_sep=(ntoks + i) > 0),
                    end="",
                    flush=True,
                )

            ntoks += new_tokens.size
            if hit_eos:
                break
            # The last kept token has not been fed to either model yet; it
            # seeds the next round of drafting and verification.
            tokens = new_tokens[-1:]

        end = time.time()
        self.model.reset_cache()
        self.draft_model.reset_cache()
        print()
        print(
            "=== GENERATED",
            ntoks,
            "TOKENS IN",
            round(end - start, 2),
            "SECONDS ===",
        )
        print("=== ACCEPTED", accepted_draft_tokens, "DRAFT TOKENS ===")
        print("=== DECODING STEPS", decoding_steps, "===")

View File

@@ -1,208 +0,0 @@
import transformers
from dataclasses import dataclass, field
from model import Llama
import mlx.core as mx
import mlx.nn as nn
import time
import numpy as np
from typing import Optional
from prompts import create_urial_prompt
def clone(x: mx.array):
    """Return a new MLX array built from a NumPy conversion of `x`."""
    as_numpy = np.array(x)
    return mx.array(as_numpy)
@dataclass
class LlamaEngine:
    """Load a Llama model (and optional draft model) and generate text.

    `model` and `draft_model` are given as Hugging Face repo paths and are
    replaced by the loaded `Llama` instances in `__post_init__`.
    """

    model: str  # path to HuggingFace repo
    draft_model: Optional[str] = None  # path to draft model
    tokenizer: transformers.AutoTokenizer = field(init=False)

    def __post_init__(self):
        """Load the tokenizer (from the primary model path) and the models."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model)
        self.model = Llama.from_hugging_face(self.model)
        if self.draft_model is not None:
            self.draft_model = Llama.from_hugging_face(self.draft_model)

    def tokenize(self, messages):
        """Turn chat `messages` into token ids.

        Uses the tokenizer's chat template when it has one; otherwise wraps
        the first message's content in the URIAL zero-shot template.
        """
        if self.tokenizer.chat_template is not None:
            return self.tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True
            )
        # use urial zero-shot template
        return self.tokenizer.encode(create_urial_prompt(messages[0]["content"]))

    def generate(
        self,
        messages,
        num_tokens: int = 100,
        temp: float = 0.8,
        draft_model: bool = False,  # if true gen with draft model
    ):
        """Generate up to `num_tokens` tokens and return the decoded text.

        Args:
            messages: Chat messages passed to `tokenize`.
            num_tokens: Stop once this many tokens have been collected.
            temp: Sampling temperature forwarded to the model.
            draft_model: Generate with the draft model instead of the main one.

        Raises:
            ValueError: If `draft_model` is true but no draft model was loaded.
        """
        if draft_model and self.draft_model is None:
            raise ValueError("No draft model was provided to LlamaEngine.")
        model = self.draft_model if draft_model else self.model
        # Ask the tokenizer for the end-of-sequence id rather than assuming 2.
        eos_id = self.tokenizer.eos_token_id

        x = mx.array([self.tokenize(messages)])
        tokens = []
        start = time.time()
        for token in model.generate(x, temp):
            token_id = token.item()
            if token_id == eos_id:
                break
            tokens.append(token_id)
            if len(tokens) >= num_tokens:
                break
        run_time = time.time() - start

        s = self.tokenizer.decode(tokens)
        print(f"=== GENERATED {len(tokens)} TOKENS in {run_time} SECONDS ===")
        return s

    # generate only supports batch size 1, so should this
    def speculative_decode(
        self, messages, num_tokens: int = 100, temp: float = 0.8, n: int = 5
    ):
        """Generate text with speculative decoding and return the completion.

        Each step the draft model proposes `n` tokens, the main model scores
        them in one pass, and the matching prefix plus one main-model token is
        appended. Both models' caches are truncated to stay in sync with the
        accepted tokens. Only batch size 1 is supported.
        """

        def sample(logits):
            if temp == 0:
                return mx.argmax(logits, axis=-1)
            return mx.random.categorical(logits * (1 / temp))

        batch_size = 1
        eos_id = self.tokenizer.eos_token_id
        tokens = mx.array([self.tokenize(messages)])
        prompt_len = tokens.shape[1]
        start = time.time()

        # Prefill the main model and sample the first token; drafting starts
        # from there, which avoids rewinding the main model.
        logit = self.model(
            tokens, read_cache=False, write_cache=True, next_token_only=True
        )
        first_token = sample(logit).reshape(batch_size, 1)
        tokens = mx.concatenate([tokens, first_token], axis=1)

        decoding_steps = 1
        n_new_tokens = 1
        accepted_draft_tokens = 0

        draft_logit = self.draft_model(
            tokens, read_cache=False, write_cache=True, next_token_only=True
        )

        while True:
            # For each decoding step: generate n tokens from the draft model.
            draft_tokens = sample(draft_logit).reshape(batch_size, 1)
            draft_tokens_left = n - 1
            # generate() updates the draft model's cache as it goes.
            for token in self.draft_model.generate(
                draft_tokens, temp=temp, read_cache=True
            ):
                draft_tokens = mx.concatenate(
                    [draft_tokens, token.reshape(batch_size, 1)], axis=1
                )
                draft_tokens_left -= 1
                if draft_tokens_left == 0:
                    break

            # The first draft token has to be verified using the last
            # verified token.
            verify_tokens = mx.concatenate([tokens[:, -1:], draft_tokens], axis=1)
            logits = self.model(verify_tokens, read_cache=True, write_cache=True)
            # check the last n + 1 tokens
            sampled = sample(logits[:, -(n + 1) :, :])

            # Only keep samples that match the draft.
            num_to_accept = 0
            for i in range(n):
                if mx.all(sampled[:, i] == draft_tokens[:, i]):
                    num_to_accept += 1
                else:
                    break
            accepted_draft_tokens += num_to_accept
            n_new_tokens += 1 + num_to_accept

            new_tokens = sampled[:, : num_to_accept + 1]
            tokens = mx.concatenate([tokens, new_tokens], axis=1)

            # Truncate the draft cache to keep only accepted tokens. With
            # prompt p and drafts t0..t(n-1), the draft cache holds
            # |p + t0 + ... + t(n-2)| after drafting. The main model accepts k
            # drafts and then produces t'; t' is never in the draft cache and
            # is fed to the draft model at the bottom of the loop.
            if num_to_accept < n:
                self.draft_model.truncate_kv_cache(n - 1 - num_to_accept)
            elif num_to_accept == n:
                # Forward the draft model on the last draft token to catch it
                # up. (Original author's note: "maybe this is bad?")
                self.draft_model(
                    draft_tokens[:, -1:],
                    read_cache=True,
                    write_cache=True,
                    next_token_only=True,
                )

            # The main model's cache holds |p + t0 + ... + t(n-1)|; drop the
            # rejected drafts. The main model has not seen t' yet, which the
            # original author believed is fine because it is passed back in
            # at the next verification.
            self.model.truncate_kv_cache(n - num_to_accept)

            decoding_steps += 1

            if n_new_tokens >= num_tokens or mx.any(new_tokens == eos_id):
                break

            # Get the next draft token based on t', preparing to do it all
            # again.
            draft_logit = self.draft_model(
                tokens[:, -1:], read_cache=True, write_cache=True, next_token_only=True
            )

        # Force evaluation so the measured time includes the computation.
        mx.eval(tokens)
        end = time.time()

        seq = np.array(tokens[0, :]).tolist()
        s = self.tokenizer.decode(seq[prompt_len:])
        print(
            "=== GENERATED",
            n_new_tokens,
            "TOKENS IN",
            round(end - start, 2),
            "SECONDS ===",
        )
        print("=== ACCEPTED", accepted_draft_tokens, "DRAFT TOKENS ===")
        print("=== DECODING STEPS", decoding_steps, "===")
        return s

View File

@@ -0,0 +1,21 @@
import time
from decoder import SpeculativeDecoder
def main():
    """Compare regular generation with speculative decoding on one prompt."""
    # The tokenizer is loaded from the primary model.
    engine = SpeculativeDecoder(
        model="meta-llama/Llama-2-7b-hf",
        draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
    )

    messages = {
        "role": "user",
        "content": "Finish the monologue: To be, or not to be...",
    }

    # Do 1 regular generation with each model to get warmed up (the first one
    # is slow)
    engine.generate(messages, max_tokens=1)
    engine.generate(messages, max_tokens=1, draft=True)

    # Time regular generation
    engine.generate(messages, max_tokens=125)

    # Time speculative decoding
    engine.speculative_decode(messages, max_tokens=125, n_draft=10)


# Guard the entry point so importing this module does not load the models.
if __name__ == "__main__":
    main()

View File

@@ -7,6 +7,14 @@ import mlx.nn as nn
from typing import Optional, Tuple
def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.float32):
    """Build an additive attention mask of shape (N, offset + N).

    Rows correspond to positions offset .. offset + N - 1 and columns to
    positions 0 .. offset + N - 1. An entry is -1e9 where the column position
    is greater than the row position and zero elsewhere.
    """
    column_positions = mx.arange(offset + N)
    if offset:
        row_positions = mx.arange(offset, offset + N)
    else:
        # Without an offset rows and columns cover the same positions.
        row_positions = column_positions
    is_future = row_positions[:, None] < column_positions[None]
    return is_future.astype(dtype) * -1e9
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@@ -29,7 +37,6 @@ class Attention(nn.Module):
self.n_heads: int = config.num_attention_heads
self.n_kv_heads: int = config.num_key_value_heads
self.repeats = self.n_heads // self.n_kv_heads
# print("heads", self.n_heads, "kv heads", self.n_kv_heads, "repeats", self.repeats)
self.head_dim = config.hidden_size // self.n_heads
self.scale = self.head_dim**-0.5
@@ -63,12 +70,8 @@ class Attention(nn.Module):
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
kv_size = a.shape[-1]
# can't use the L from x here, this is like cross-attention during decoding
return a.reshape([B, self.n_heads, -1, kv_size])
# cache should be with unrepeated kv, otherwise GQA is pointless lol
# keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
@@ -79,17 +82,9 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
# print("queries shape", queries.shape, "keys shape", keys.shape, "values shape", values.shape)
scores = (queries * self.scale) @ repeat(keys).transpose(0, 1, 3, 2)
if mask is not None:
# print("we need to add mask of shape", mask.shape, "to scores of shape", scores.shape)
if cache is None:
scores += mask
else:
# we're doing "cross-attn"; add mask to the "end" of the attn matrix along the K dimension
a, b = mx.split(scores, indices_or_sections=[-mask.shape[-1]], axis=-1)
scores = mx.concatenate([a, b + mask], axis=-1)
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ repeat(values)).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
@@ -148,75 +143,55 @@ class Llama(nn.Module):
]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.kv_cache = []
self.reset_cache()
def truncate_kv_cache(self, num_to_truncate):
def truncate_cache(self, num_to_truncate):
cache_length = self.kv_cache[0][0].shape[2]
num_to_truncate = min(num_to_truncate, cache_length)
if num_to_truncate == 0:
return False
else:
if num_to_truncate < cache_length:
self.kv_cache = tree_map(
lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache
)
return True
else:
self.reset_cache()
def reset_cache(self):
    """Clear the cache by storing one empty (None) slot per layer."""
    self.kv_cache = [None for _ in self.layers]
def __call__(
self,
x: mx.array,
read_cache: bool = False,
write_cache: bool = False,
next_token_only: bool = False,
):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embed_tokens.weight.dtype)
if read_cache and len(self.kv_cache) != len(self.layers):
raise RuntimeError(
f"Length of cache ({len(self.kv_cache)}) must match number of layers ({len(self.layers)})"
)
if self.kv_cache[0]:
offset = self.kv_cache[0][0].shape[-2]
else:
offset = 0
if x.shape[1] > 1:
mask = create_additive_causal_mask(x.shape[1], offset)
mask = mask.astype(self.embed_tokens.weight.dtype)
else:
mask = None
x = self.embed_tokens(x)
for idx, layer in enumerate(self.layers):
x, c = layer(x, mask, cache=self.kv_cache[idx] if read_cache else None)
if write_cache:
if len(self.kv_cache) == 0:
self.kv_cache = [None] * len(self.layers)
self.kv_cache[idx] = c
x = self.norm(x)
x, self.kv_cache[idx] = layer(x, mask, cache=self.kv_cache[idx])
if next_token_only:
x = x[:, -1]
x = self.norm(x)
return self.lm_head(x)
@classmethod
def from_hugging_face(cls, model_path: str):
config = LlamaConfig.from_pretrained(model_path)
torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
mx_weights = {
k.replace("model.", ""): mx.array(v.numpy())
weights = {
k.replace("model.", ""): mx.array(v.numpy(), mx.float16)
for k, v in torch_weights.items()
}
for k in mx_weights.keys():
mx_weights[k] = mx_weights[k].astype(mx.float16)
mlx_model = cls(config)
mlx_model.update(tree_unflatten(list(mx_weights.items())))
return mlx_model
def generate(self, x: mx.array, temp=0.0, read_cache: bool = False):
# Make an additive causal mask. We will need that to process the prompt.
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embed_tokens.weight.dtype)
logit = self(x, read_cache=read_cache, write_cache=True, next_token_only=True)
tok = sample(logit)
yield tok
while True:
x = tok.reshape(-1, 1)
logit = self(x, read_cache=True, write_cache=True, next_token_only=True)
tok = sample(logit)
yield tok
model = cls(config)
model.update(tree_unflatten(list(weights.items())))
mx.eval(model.parameters())
return model

View File

@@ -1,25 +0,0 @@
import time
from engine import LlamaEngine
# Benchmark script: runs the same prompt through regular generation and
# speculative decoding and prints the wall-clock time of each.
# The tokenizer is loaded from the primary model; its chat template is used
# when it defines one, otherwise the engine falls back to the URIAL prompt.
engine = LlamaEngine(
    model="meta-llama/Llama-2-7b-hf",
    draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
)
# A single-turn chat conversation in role/content message format.
messages = [
    {"role": "user", "content": "Finish the monologue: To be, or not to be..."}
]
# Do 1 regular generation to get warmed up (the first one is slow)
engine.generate(messages, num_tokens=1, temp=0.1)
# Time regular generation
start = time.time()
engine.generate(messages, num_tokens=125, temp=0.1)
print(f"Regular generation took {time.time() - start} seconds")
# Time speculative decoding (the draft model proposes n=5 tokens per step)
start = time.time()
engine.speculative_decode(messages, num_tokens=125, temp=0.1, n=5)
print(f"Speculative decoding took {time.time() - start} seconds")