speculative decoding

This commit is contained in:
Benjamin Anderson 2023-12-20 00:25:49 -06:00 committed by Awni Hannun
parent 50fceb1a28
commit a436d198ec
6 changed files with 457 additions and 0 deletions

View File

@ -0,0 +1,16 @@
## Speculative Decoding with MLX
This example implements [speculative decoding](https://arxiv.org/abs/2211.17192), which allows you to use a smaller draft model to predict several tokens, and then a larger verification model to check them all in parallel. The result is output that is identical to what the larger model would produce, but with far fewer forward passes of the larger model (as long as the draft model is good enough at guessing).
Install the requirements and then you can try it out:
```
cd speculative_decoding
pip install -r requirements.txt
python test.py
```
For the draft model to guess well, it's generally good if the two models are trained on similar data, with a similar chat template, etc. For example, you could use Meta's 7B Llama as a draft model for the 13B Llama. In my tests, I've mostly used TinyLlama as a draft model for Llama-7B. The chat versions of TinyLlama and Llama-7B-Chat are trained with different templates, but it works OK. Alternatively, you can use base models, and a prompt to make the model act like a chat model (e.g. [URIAL](https://arxiv.org/abs/2312.01552)).
I believe the implementation is *correct* (it produces near-identical output with regular generation vs. speculative decoding, and when speculative decoding is enabled, the draft model does correctly predict many tokens). However, it assumes a batch size of 1 at the moment (I'm not actually sure how to handle batching where some drafts might have more correct tokens than others). Also I feel like it could be faster!
Before merging this in, I would appreciate some help understanding how to make this faster and optimizing the performance so it's actually useful!

View File

@ -0,0 +1,190 @@
import transformers
from dataclasses import dataclass, field
from model import Llama
import mlx.core as mx
import mlx.nn as nn
import time
import numpy as np
from typing import Optional
from prompts import create_urial_prompt
def clone(x: mx.array):
    """Return a new ``mx.array`` built from a NumPy conversion of ``x``."""
    as_numpy = np.array(x)
    return mx.array(as_numpy)
@dataclass
class LlamaEngine:
    """Text generation with a main Llama model and an optional draft model.

    ``model`` and ``draft_model`` are passed in as HuggingFace repo paths.
    ``__post_init__`` replaces both attributes with the loaded ``Llama``
    instances; the tokenizer is always loaded from the main model's path.
    Both generation methods work on a single sequence (batch size 1).
    """

    model: str  # path to HuggingFace repo (becomes the loaded model)
    draft_model: Optional[str] = None  # path to draft model (becomes the loaded model)
    tokenizer: transformers.AutoTokenizer = field(init=False)

    def __post_init__(self):
        """Load the tokenizer and swap the model paths for loaded models."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model)
        self.model = Llama.from_hugging_face(self.model)
        if self.draft_model is not None:
            self.draft_model = Llama.from_hugging_face(self.draft_model)

    def _eos_token_id(self) -> int:
        """Return the tokenizer's end-of-sequence id, or 2 if it has none."""
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        return 2 if eos_id is None else eos_id

    @staticmethod
    def _sample(logits, temp: float):
        """Sample token ids along the last axis of ``logits``.

        A temperature of exactly 0 selects the arg-max instead of dividing
        by zero.
        """
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp), axis=-1)

    def tokenize(self, messages):
        """Turn a list of chat messages into prompt token ids.

        Uses the tokenizer's chat template when it has one. Otherwise only
        the content of the FIRST message is used, wrapped in the URIAL
        zero-shot prompt.
        """
        if self.tokenizer.chat_template is not None:
            tokenized = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
            )
        else:
            # use urial zero-shot template
            tokenized = self.tokenizer.encode(
                create_urial_prompt(messages[0]["content"])
            )
        return tokenized

    def generate(
        self,
        messages,
        num_tokens: int = 100,
        temp: float = 0.8,
        draft_model: bool = False,  # if true gen with draft model
    ):
        """Generate a completion token by token, without speculation.

        Args:
            messages: chat messages, as accepted by ``tokenize``.
            num_tokens: stop after this many generated tokens.
            temp: sampling temperature, forwarded to the model's ``generate``.
            draft_model: when true, generate with the draft model instead of
                the main model.

        Returns:
            The decoded completion. Generation stops before an
            end-of-sequence token, which is not included.

        Raises:
            ValueError: if ``draft_model`` is true but no draft model was
                configured.
        """
        if draft_model:
            if self.draft_model is None:
                raise ValueError(
                    "draft_model=True requires the engine to be created "
                    "with a draft model"
                )
            model = self.draft_model
        else:
            model = self.model

        eos_id = self._eos_token_id()
        x = mx.array([self.tokenize(messages)])
        tokens = []
        start = time.time()
        for token in model.generate(x, temp):
            # .item() forces evaluation of the lazily computed token.
            token_id = token.item()
            if token_id == eos_id:
                break
            tokens.append(token_id)
            if len(tokens) >= num_tokens:
                break
        run_time = time.time() - start
        s = self.tokenizer.decode(tokens)
        print(f"=== GENERATED {len(tokens)} TOKENS in {run_time} SECONDS ===")
        return s

    # generate only supports batch size 1, so should this
    def speculative_decode(
        self,
        messages,
        num_tokens: int = 100,
        temp: float = 0.8,
        n: int = 5,
    ):
        """Generate a completion using speculative decoding.

        Each step lets the draft model propose ``n`` tokens, runs the main
        model once over them, and keeps the proposals up to the first one
        that differs from the main model's own sample, plus the main model's
        token at that position.

        Args:
            messages: chat messages, as accepted by ``tokenize``.
            num_tokens: maximum number of tokens in the returned completion.
            temp: sampling temperature; 0 means greedy (arg-max) decoding.
            n: number of draft tokens proposed per step (at least 1).

        Returns:
            The decoded completion, cut before the first end-of-sequence
            token and limited to ``num_tokens`` tokens.

        Raises:
            ValueError: if no draft model was configured or ``n < 1``.
        """
        if self.draft_model is None:
            raise ValueError("speculative_decode requires a draft model")
        if n < 1:
            raise ValueError("n must be at least 1")

        batch_size = 1
        eos_id = self._eos_token_id()
        tokens = mx.array([self.tokenize(messages)])
        prompt_len = tokens.shape[1]
        start = time.time()

        # Prefill the main model and sample the first token from it, so the
        # draft model can start from there and the main cache never needs
        # rewinding past the prompt.
        logit = self.model(
            tokens, read_cache=False, write_cache=True, next_token_only=True
        )
        first_token = self._sample(logit, temp).reshape(batch_size, 1)
        tokens = mx.concatenate([tokens, first_token], axis=1)
        decoding_steps = 1
        n_new_tokens = 1
        accepted_draft_tokens = 0

        finished = n_new_tokens >= num_tokens or first_token.item() == eos_id
        if not finished:
            # Prefill the draft model with the prompt plus the first token.
            draft_logit = self.draft_model(
                tokens, read_cache=False, write_cache=True, next_token_only=True
            )

        while not finished:
            # Draft n tokens. The last drafted token is never fed back into
            # the draft model, so its cache ends one token short of the draft.
            draft_tokens = self._sample(draft_logit, temp).reshape(batch_size, 1)
            for _ in range(n - 1):
                step_logit = self.draft_model(
                    draft_tokens[:, -1:],
                    read_cache=True,
                    write_cache=True,
                    next_token_only=True,
                )
                next_draft = self._sample(step_logit, temp).reshape(batch_size, 1)
                draft_tokens = mx.concatenate([draft_tokens, next_draft], axis=1)

            # The first draft token is verified from the last accepted token,
            # which the main model has not yet seen as an input.
            verify_tokens = mx.concatenate([tokens[:, -1:], draft_tokens], axis=1)
            logits = self.model(verify_tokens, read_cache=True, write_cache=True)
            # check the last n + 1 tokens
            logits = logits[:, -(n + 1):, :]
            sampled = self._sample(logits, temp)

            # only keep samples that match the draft
            num_to_accept = 0
            for i in range(n):
                if mx.all(sampled[:, i] == draft_tokens[:, i]).item():
                    num_to_accept += 1
                else:
                    break

            accepted_draft_tokens += num_to_accept
            n_new_tokens += 1 + num_to_accept
            # Accepted draft tokens plus the main model's own next token t'.
            new_tokens = sampled[:, : num_to_accept + 1]
            tokens = mx.concatenate([tokens, new_tokens], axis=1)

            # Draft cache: it saw all draft tokens but the last as inputs.
            # Keep only the accepted ones; t' is fed in at the next step.
            if num_to_accept < n:
                self.draft_model.truncate_kv_cache(n - 1 - num_to_accept)
            else:
                # Everything was accepted: catch the draft model up on the
                # final draft token, which it produced but never consumed.
                self.draft_model(
                    draft_tokens[:, -1:],
                    read_cache=True,
                    write_cache=True,
                    next_token_only=True,
                )

            # Main cache: it consumed the last accepted token plus all n
            # draft tokens; drop the rejected ones. t' is not in the cache,
            # which is fine because it is passed back in at verification.
            self.model.truncate_kv_cache(n - num_to_accept)

            decoding_steps += 1
            finished = (
                n_new_tokens >= num_tokens
                or mx.any(new_tokens == eos_id).item()
            )
            if not finished:
                # get the next draft token based on t', preparing to do it
                # all again!
                draft_logit = self.draft_model(
                    tokens[:, -1:],
                    read_cache=True,
                    write_cache=True,
                    next_token_only=True,
                )

        mx.eval(tokens)
        end = time.time()

        # A step can overshoot: drop the EOS token and anything sampled after
        # it, and never return more than num_tokens tokens.
        generated = np.array(tokens[0, prompt_len:]).tolist()
        if eos_id in generated:
            generated = generated[: generated.index(eos_id)]
        generated = generated[: max(num_tokens, 0)]
        s = self.tokenizer.decode(generated)

        print(
            "=== GENERATED", len(generated), "TOKENS IN",
            round(end - start, 2), "SECONDS ===",
        )
        print("=== ACCEPTED", accepted_draft_tokens, "DRAFT TOKENS ===")
        print("=== DECODING STEPS", decoding_steps, "===")
        return s

View File

@ -0,0 +1,195 @@
from transformers import LlamaConfig, AutoModelForCausalLM
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten, tree_map
import mlx.core as mx
import mlx.nn as nn
from typing import Optional, Tuple
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a per-feature weight."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        # Multiply by the reciprocal root of the mean square (plus eps).
        mean_square = x.square().mean(-1, keepdims=True)
        return x * mx.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        # Normalise in float32, then cast back to the input dtype before
        # applying the weight.
        normed = self._norm(x.astype(mx.float32))
        return self.weight * normed.astype(x.dtype)
class Attention(nn.Module):
    """Attention with grouped key/value heads, RoPE and an optional KV cache."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.n_heads: int = config.num_attention_heads
        self.n_kv_heads: int = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.repeats = self.n_heads // self.n_kv_heads
        self.head_dim = config.hidden_size // self.n_heads
        self.scale = self.head_dim ** -0.5

        hidden = config.hidden_size
        kv_dim = hidden // self.repeats
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, kv_dim, bias=False)
        self.v_proj = nn.Linear(hidden, kv_dim, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)
        self.rope = nn.RoPE(self.head_dim, traditional=False)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Attend over ``x`` (and the cached keys/values, if given).

        Returns a pair: the projected attention output and the
        ``(keys, values)`` tuple to use as the cache for the next call. The
        returned keys/values are NOT repeated across query heads.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t, num_heads):
            # (B, L, H * head_dim) -> (B, H, L, head_dim)
            return t.reshape(batch, seq_len, num_heads, -1).transpose(0, 2, 1, 3)

        def expand_kv(t):
            # Repeat each key/value head so there is one per query head.
            # The sequence axis is inferred because, with a cache, it is
            # longer than the current input.
            stacked = mx.concatenate([mx.expand_dims(t, 2)] * self.repeats, axis=2)
            return stacked.reshape([batch, self.n_heads, -1, stacked.shape[-1]])

        queries = split_heads(self.q_proj(x), self.n_heads)
        keys = split_heads(self.k_proj(x), self.n_kv_heads)
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            key_cache, value_cache = cache
            # New positions start right after the cached ones.
            offset = key_cache.shape[2]
            queries = self.rope(queries, offset=offset)
            keys = self.rope(keys, offset=offset)
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)

        scores = (queries * self.scale) @ expand_kv(keys).transpose(0, 1, 3, 2)
        if mask is not None:
            if cache is None:
                scores = scores + mask
            else:
                # The mask only covers the new positions, so add it to the
                # trailing columns of the score matrix along the key axis.
                cached_part, new_part = mx.split(
                    scores, indices_or_sections=[-mask.shape[-1]], axis=-1
                )
                scores = mx.concatenate([cached_part, new_part + mask], axis=-1)

        # Softmax in float32, then back to the score dtype.
        weights = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        attended = weights @ expand_kv(values)
        output = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(output), (keys, values)
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)

    def __call__(self, x) -> mx.array:
        gate = nn.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
class TransformerBlock(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm MLP, each with a residual."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.dim = config.hidden_size
        self.self_attn = Attention(config=config)
        self.mlp = FeedForward(config=config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Return ``(output, cache)`` where ``cache`` is what the attention layer returned."""
        attn_out, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        hidden = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out, cache
class Llama(nn.Module):
    """Llama decoder stack with a per-layer key/value cache stored on the model.

    ``kv_cache`` is either empty or holds one ``(keys, values)`` tuple per
    layer, written by calls made with ``write_cache=True``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
        ]
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.kv_cache = []

    @staticmethod
    def _sample(logits, temp):
        """Sample token ids from ``logits``; a temperature of 0 takes the arg-max."""
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    def truncate_kv_cache(self, num_to_truncate):
        """Drop the last ``num_to_truncate`` positions from every cached tensor.

        The count is clamped to the current cache length. Returns ``True`` if
        anything was removed, ``False`` if the cache is empty or the
        (clamped) count is not positive.
        """
        # An empty cache has nothing to index, and a non-positive count
        # would turn the slice below into "keep a prefix" instead of a no-op.
        if not self.kv_cache or num_to_truncate <= 0:
            return False
        cache_length = self.kv_cache[0][0].shape[2]
        num_to_truncate = min(num_to_truncate, cache_length)
        if num_to_truncate <= 0:
            return False
        self.kv_cache = tree_map(
            lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache
        )
        return True

    def __call__(
        self,
        x: mx.array,
        read_cache: bool = False,
        write_cache: bool = False,
        next_token_only: bool = False,
    ):
        """Run the model on token ids ``x`` and return logits.

        Args:
            x: token ids; axis 1 is the sequence axis.
            read_cache: attend over the stored cache in addition to ``x``.
            write_cache: store each layer's returned keys/values as the new
                cache.
            next_token_only: return logits for the last position only.

        Raises:
            RuntimeError: if ``read_cache`` is set but the cache does not
                hold one entry per layer.
        """
        # The causal mask only spans the new tokens; the attention layer
        # aligns it with the trailing key positions when a cache is used.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embed_tokens.weight.dtype)
        if read_cache and len(self.kv_cache) != len(self.layers):
            raise RuntimeError(f"Length of cache ({len(self.kv_cache)}) must match number of layers ({len(self.layers)})")
        x = self.embed_tokens(x)
        for idx, layer in enumerate(self.layers):
            x, c = layer(x, mask, cache=self.kv_cache[idx] if read_cache else None)
            if write_cache:
                if len(self.kv_cache) == 0:
                    self.kv_cache = [None] * len(self.layers)
                self.kv_cache[idx] = c
        x = self.norm(x)
        if next_token_only:
            x = x[:, -1]
        return self.lm_head(x)

    @classmethod
    def from_hugging_face(cls, model_path: str):
        """Build a model from a HuggingFace checkpoint, with float16 weights."""
        config = LlamaConfig.from_pretrained(model_path)
        torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
        prefix = "model."
        mx_weights = {}
        for name, tensor in torch_weights.items():
            # Strip only the leading "model." so the names line up with this
            # class's attributes; other occurrences are left alone.
            if name.startswith(prefix):
                name = name[len(prefix):]
            mx_weights[name] = mx.array(tensor.numpy()).astype(mx.float16)
        mlx_model = cls(config)
        mlx_model.update(tree_unflatten(list(mx_weights.items())))
        return mlx_model

    def generate(
        self,
        x: mx.array,
        temp=1.0,
        read_cache: bool = False,
    ):
        """Yield sampled tokens indefinitely, starting from prompt ``x``.

        The first forward pass processes all of ``x`` (on top of the existing
        cache when ``read_cache`` is true); every later pass feeds back the
        token just yielded. The cache is always written. A temperature of 0
        yields arg-max tokens. The caller decides when to stop iterating.
        """
        logit = self(x, read_cache=read_cache, write_cache=True, next_token_only=True)
        tok = self._sample(logit, temp)
        yield tok
        while True:
            x = tok.reshape(-1, 1)
            logit = self(x, read_cache=True, write_cache=True, next_token_only=True)
            tok = self._sample(logit, temp)
            yield tok

View File

@ -0,0 +1,28 @@
# One-shot URIAL prompt: instructions followed by a single worked
# query/answer example. Used for base models that have no chat template.
URIAL_1_SHOT = """# Instruction
Below is a list of conversations between a human and an AI assistant (you).
Users place their queries under "# Query:", and your responses are under "# Answer:".
You are a helpful, respectful, and honest assistant.
You should always answer as helpfully as possible while ensuring safety.
Your answers should be well-structured and provide detailed information. They should also have an engaging tone.
Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.
Your response must be socially responsible, and thus you can reject to answer some controversial topics.
# Query:
```Can you tell me some common types of renewable energy sources?```
# Answer:
```
Sure, here are some of the most common types of renewable energy sources:
1. Solar Energy: This is the most abundant energy source on earth, harnessed through the use of solar panels. These panels convert sunlight into electricity without any moving parts, noise, pollution, or damage to the environment.
2. Wind Energy: Wind turbines convert the kinetic energy in the wind into mechanical power. This mechanical power can be used for specific tasks (such as pumping water) or converted into electricity to power homes, businesses, and schools.
3. Hydropower: Generated by using electricity generators to capture the energy from falling or fast-moving water. This renewable source can come from various ways, including a large dam on a river, tidal or wave energy from the ocean, or using small scale turbines in streams.
4. Geothermal Energy: This type of energy is generated from the heat deep within the Earth. This heat can be used directly for heating buildings or to generate electricity. It is continuously produced inside the Earth and is nearly as reliable as the tides.
5. Biomass Energy: Biomass is organic material that comes from plants and animals, and it contains stored energy from the sun. This energy can be burned directly or converted into biofuel which can burn more efficiently.
Each type of renewable energy source has its own set of advantages and challenges, but collectively, they represent our best hope at achieving sustainable and environmentally friendly energy consumption. Please let me know if you have any other questions!
```"""


def create_urial_prompt(message: str):
    """Append ``message`` as a new query to the one-shot prompt and open the answer block."""
    query_block = "\n\n# Query:\n```" + message + "```"
    answer_header = "\n\n# Answer:\n```\n"
    return "".join((URIAL_1_SHOT, query_block, answer_header))

View File

@ -0,0 +1,3 @@
mlx>=0.0.5
transformers
numpy

View File

@ -0,0 +1,25 @@
import time
from engine import LlamaEngine
def main():
    """Time regular generation against speculative decoding on one prompt."""
    # This will use the chat template from the primary model
    engine = LlamaEngine(
        model="meta-llama/Llama-2-7b-hf",
        draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
    )
    messages = [
        {"role": "user", "content": "Finish the monologue: To be, or not to be..."}
    ]

    # Do 1 regular generation to get warmed up (the first one is slow)
    engine.generate(messages, num_tokens=1, temp=0.1)

    # Time regular generation
    start = time.time()
    engine.generate(messages, num_tokens=125, temp=0.1)
    print(f"Regular generation took {time.time() - start} seconds")

    # Time speculative decoding
    start = time.time()
    engine.speculative_decode(messages, num_tokens=125, temp=0.1, n=5)
    print(f"Speculative decoding took {time.time() - start} seconds")


# Guarded so that importing this module does not load models or run the benchmark.
if __name__ == "__main__":
    main()