mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
speculative decoding
This commit is contained in:
parent
50fceb1a28
commit
a436d198ec
16
speculative_decoding/README.md
Normal file
16
speculative_decoding/README.md
Normal file
@ -0,0 +1,16 @@
|
||||
## Speculative Decoding with MLX
|
||||
|
||||
This example implements [speculative decoding](https://arxiv.org/abs/2211.17192), which allows you to use a smaller draft model to predict several tokens, and then a larger verification model to check them all in parallel. The results are output that is identical to what the larger model would produce, but with far fewer forward passes (as long as the reference model is good enough at guessing).
|
||||
|
||||
Install the requirements and then you can try it out:
|
||||
```
|
||||
cd speculative_decoding
|
||||
pip install -r requirements.txt
|
||||
python test.py
|
||||
```
|
||||
|
||||
In order for that to happen, it's generally good if the models are trained on similar data, with a similar chat template, etc. For example, you could use Meta's 7B Llama as a draft model for the 13B Llama. In my tests, I've mostly used TinyLlama as a draft model for Llama-7B. The chat versions of TinyLlama and Llama-7B-Chat are trained with different templates, but it works OK. Alternatively, you can use base models, and a prompt to make the model act like a chat model (e.g. [URIAL](https://arxiv.org/abs/2312.01552)).
|
||||
|
||||
I believe the implementation is *correct* (it produces near-identical output with regular generation vs. speculative decoding, and when speculative decoding is enabled, the draft model does correctly predict many tokens). However, it assumes a batch size of 1 at the moment (I'm not actually sure how to handle batching where some drafts might have more correct tokens than others). Also I feel like it could be faster!
|
||||
|
||||
Before merging this in, I would appreciate some help understanding how to make this faster and optimizing the performance so it's actually useful!
|
190
speculative_decoding/engine.py
Normal file
190
speculative_decoding/engine.py
Normal file
@ -0,0 +1,190 @@
|
||||
import transformers
|
||||
from dataclasses import dataclass, field
|
||||
from model import Llama
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
from prompts import create_urial_prompt
|
||||
|
||||
def clone(x: mx.array):
|
||||
return mx.array(np.array(x))
|
||||
|
||||
@dataclass
|
||||
class LlamaEngine:
|
||||
model: str # path to HuggingFace repo
|
||||
draft_model: Optional[str] = None # path to draft model
|
||||
tokenizer: transformers.AutoTokenizer = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model)
|
||||
self.model = Llama.from_hugging_face(self.model)
|
||||
if self.draft_model is not None:
|
||||
self.draft_model = Llama.from_hugging_face(self.draft_model)
|
||||
|
||||
def tokenize(self, messages):
|
||||
if self.tokenizer.chat_template is not None:
|
||||
tokenized = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
# use urial zero-shot template
|
||||
tokenized = self.tokenizer.encode(create_urial_prompt(messages[0]["content"]))
|
||||
|
||||
return tokenized
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages,
|
||||
num_tokens: int = 100,
|
||||
temp: float = 0.8,
|
||||
draft_model: bool = False # if true gen with draft model
|
||||
):
|
||||
tokenized = self.tokenize(messages)
|
||||
x = mx.array([tokenized])
|
||||
tokens = []
|
||||
start = time.time()
|
||||
for token in self.model.generate(x, temp):
|
||||
if token.item() == 2:
|
||||
break
|
||||
tokens.append(token)
|
||||
if len(tokens) >= num_tokens:
|
||||
break
|
||||
mx.eval(tokens)
|
||||
run_time = time.time() - start
|
||||
tokens = [t.item() for t in tokens]
|
||||
s = self.tokenizer.decode(tokens)
|
||||
# print("=== COMPLETION ===")
|
||||
# print(s)
|
||||
print(f"=== GENERATED {len(tokens)} TOKENS in {run_time} SECONDS ===")
|
||||
return s
|
||||
|
||||
# generate only supports batch size 1, so should this
|
||||
def speculative_decode(
|
||||
self,
|
||||
messages,
|
||||
num_tokens: int = 100,
|
||||
temp: float = 0.8,
|
||||
n: int = 5
|
||||
):
|
||||
batch_size = 1
|
||||
tokenized = self.tokenize(messages)
|
||||
tokens = mx.array([tokenized])
|
||||
prompt_len = tokens.shape[1]
|
||||
start = time.time()
|
||||
|
||||
# prefill the main model
|
||||
# sample first token & write draft from there (avoids rewinding main model which i was doing before)
|
||||
logit = self.model(tokens, read_cache=False, write_cache=True, next_token_only=True)
|
||||
first_token = mx.random.categorical(logit * (1 / temp)).reshape(batch_size, 1)
|
||||
tokens = mx.concatenate([tokens, first_token], axis=1)
|
||||
|
||||
decoding_steps = 1
|
||||
n_new_tokens = 1
|
||||
accepted_draft_tokens = 0
|
||||
draft_logit = self.draft_model(tokens, read_cache=False, write_cache=True, next_token_only=True)
|
||||
# print("Before doing any speculative decoding, draft model cache is: ", self.draft_model.kv_cache[0][0].shape[2])
|
||||
# print("And prompt length is: ", prompt_len)
|
||||
while True:
|
||||
# for each decoding step: generate n tokens from a draft model
|
||||
draft_tokens = mx.random.categorical(draft_logit * (1 / temp)).reshape(batch_size, 1)
|
||||
draft_tokens_left = n - 1
|
||||
for token in self.draft_model.generate( # generate automatically updates the cache, it has to
|
||||
draft_tokens,
|
||||
temp=temp,
|
||||
read_cache=True
|
||||
):
|
||||
draft_tokens = mx.concatenate([draft_tokens, token.reshape(batch_size, 1)], axis=1)
|
||||
draft_tokens_left -= 1
|
||||
if draft_tokens_left == 0:
|
||||
break
|
||||
|
||||
# have to verify the first draft token using the last verified token
|
||||
verify_tokens = mx.concatenate([tokens[:, -1:], draft_tokens], axis=1)
|
||||
# print("Tokens so far: ", self.tokenizer.decode(np.array(tokens[0, prompt_len:]).tolist()))
|
||||
# print("Predicted draft tokens: [", self.tokenizer.decode(np.array(draft_tokens[0, :]).tolist()), "]")
|
||||
logits = self.model(verify_tokens, read_cache=True, write_cache=True)
|
||||
# check the last n + 1 tokens
|
||||
logits = logits[:, -(n + 1):, :]
|
||||
sampled = mx.random.categorical(logits * (1 / temp), axis=-1)
|
||||
# print("Sampled tokens: [", self.tokenizer.decode(np.array(sampled[0, :]).tolist()), "]")
|
||||
|
||||
# only keep samples that match the draft
|
||||
num_to_accept = 0
|
||||
for i in range(n):
|
||||
if mx.all(sampled[:, i] == draft_tokens[:, i]):
|
||||
num_to_accept += 1
|
||||
else:
|
||||
break
|
||||
# print("Accepting", num_to_accept)
|
||||
accepted_draft_tokens += num_to_accept
|
||||
n_new_tokens += (1 + num_to_accept)
|
||||
|
||||
new_tokens = sampled[:, :num_to_accept + 1]
|
||||
tokens = mx.concatenate([tokens, new_tokens], axis=1)
|
||||
|
||||
# truncate draft cache to keep only accepted tokens
|
||||
# what tokens have been INPUT into the draft model? let's say n = 5, start with |p| tokens
|
||||
# |p| -> t0; |p + t0| -> t1; |p + t0 + t1| -> t2; |p + t0 + t1 + t2| -> t3; |p + t0 + t1 + t2 + t3| -> t4;
|
||||
# return -> t0 - t4, cache has |p + t0 + t1 + t2 + t3|
|
||||
# main model accepts whatever is correct, then generates t'
|
||||
# if 0 accepted: cache should have |p + t'|
|
||||
# if 1 accepted: |p + t0 + t'|
|
||||
# if 2 accepted: |p + t0 + t1 + t'|
|
||||
# ...
|
||||
# if 5 accepted: |p + t0 + t1 + t2 + t3 + t4 + t'|
|
||||
# we're always going to have to show the draft model the 1 token where it went off
|
||||
# the rails and we rejected it and took the real model, cause that won't be in its cache
|
||||
# print("After speculative decoding, before truncation, draft cache has: ", self.draft_model.kv_cache[0][0].shape[2])
|
||||
if num_to_accept < n:
|
||||
self.draft_model.truncate_kv_cache(n - 1 - num_to_accept)
|
||||
# print("Truncated draft cache by", n - 1 - num_to_accept, "now it has", self.draft_model.kv_cache[0][0].shape[2])
|
||||
elif num_to_accept == n:
|
||||
# forward the model on the last draft token to catch it up
|
||||
# maybe this is bad?
|
||||
self.draft_model(draft_tokens[:, -1:], read_cache=True, write_cache=True, next_token_only=True)
|
||||
|
||||
# now how to truncate the full model's cache?
|
||||
# it has |p + t0 + t1 + t2 + t3 + t4|
|
||||
# if 0 accepted: truncate back to p
|
||||
# if 1 accepted: truncate to p + t0
|
||||
self.model.truncate_kv_cache(n - num_to_accept) # how many to truncate? i think 1 off from draft model? idk
|
||||
# NOTE: main model doesn't know that it predicted t' (the 1 non-draft token)
|
||||
# i think this is ok because it's the last accepted token and will be passed back in at verification time
|
||||
|
||||
# NOTE: model is now (or could be!) 1 token ahead of draft model cause if it accepts the full
|
||||
# draft it's now predicted 1 token past the draft token's last token. must account for this.
|
||||
|
||||
|
||||
# print("Length of accepted tokens: ", tokens.shape[1])
|
||||
# print("Length of draft model cache: ", self.draft_model.kv_cache[0][0].shape[2])
|
||||
# print("Length of main model cache: ", self.model.kv_cache[0][0].shape[2])
|
||||
decoding_steps += 1
|
||||
|
||||
if n_new_tokens >= num_tokens or mx.any(new_tokens == 2):
|
||||
break
|
||||
|
||||
# get the next draft token based on t', preparing to do it all again!
|
||||
# print("Getting the token that comes after: ", self.tokenizer.decode(np.array(tokens[0, -1:]).tolist()))
|
||||
draft_logit = self.draft_model(tokens[:, -1:], read_cache=True, write_cache=True, next_token_only=True)
|
||||
|
||||
mx.eval(tokens)
|
||||
end = time.time()
|
||||
|
||||
seq = np.array(tokens[0, :]).tolist()
|
||||
s = self.tokenizer.decode(seq[prompt_len:])
|
||||
# print(f"=== COMPLETION {0 + 1} ===")
|
||||
# print(s)
|
||||
|
||||
print("=== GENERATED", n_new_tokens, "TOKENS IN", round(end - start, 2), "SECONDS ===")
|
||||
print("=== ACCEPTED", accepted_draft_tokens, "DRAFT TOKENS ===")
|
||||
print("=== DECODING STEPS", decoding_steps, "===")
|
||||
return s
|
||||
|
||||
|
||||
|
||||
|
195
speculative_decoding/model.py
Normal file
195
speculative_decoding/model.py
Normal file
@ -0,0 +1,195 @@
|
||||
from transformers import LlamaConfig, AutoModelForCausalLM
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten, tree_map
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.n_heads: int = config.num_attention_heads
|
||||
self.n_kv_heads: int = config.num_key_value_heads
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
# print("heads", self.n_heads, "kv heads", self.n_kv_heads, "repeats", self.repeats)
|
||||
self.head_dim = config.hidden_size // self.n_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.k_proj = nn.Linear(config.hidden_size, config.hidden_size // self.repeats, bias=False)
|
||||
self.v_proj = nn.Linear(config.hidden_size, config.hidden_size // self.repeats, bias=False)
|
||||
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.rope = nn.RoPE(self.head_dim, traditional=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) # B, n_kv_heads, L, head_dim
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
kv_size = a.shape[-1]
|
||||
# can't use the L from x here, this is like cross-attention during decoding
|
||||
return a.reshape([B, self.n_heads, -1, kv_size])
|
||||
|
||||
# cache should be with unrepeated kv, otherwise GQA is pointless lol
|
||||
# keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# print("queries shape", queries.shape, "keys shape", keys.shape, "values shape", values.shape)
|
||||
|
||||
scores = (queries * self.scale) @ repeat(keys).transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
# print("we need to add mask of shape", mask.shape, "to scores of shape", scores.shape)
|
||||
if cache is None:
|
||||
scores += mask
|
||||
else:
|
||||
# we're doing "cross-attn"; add mask to the "end" of the attn matrix along the K dimension
|
||||
a, b = mx.split(scores, indices_or_sections=[-mask.shape[-1]], axis=-1)
|
||||
scores = mx.concatenate([a, b + mask], axis=-1)
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ repeat(values)).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.n_heads = config.num_attention_heads
|
||||
self.dim = config.hidden_size
|
||||
self.self_attn = Attention(config=config)
|
||||
self.mlp = FeedForward(config=config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [TransformerBlock(config=config) for _ in range(config.num_hidden_layers)]
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.kv_cache = []
|
||||
|
||||
def truncate_kv_cache(self, num_to_truncate):
|
||||
cache_length = self.kv_cache[0][0].shape[2]
|
||||
num_to_truncate = min(num_to_truncate, cache_length)
|
||||
if num_to_truncate == 0:
|
||||
return False
|
||||
else:
|
||||
self.kv_cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache)
|
||||
return True
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
read_cache: bool = False,
|
||||
write_cache: bool = False,
|
||||
next_token_only: bool = False
|
||||
):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
if read_cache and len(self.kv_cache) != len(self.layers):
|
||||
raise RuntimeError(f"Length of cache ({len(self.kv_cache)}) must match number of layers ({len(self.layers)})")
|
||||
x = self.embed_tokens(x)
|
||||
for idx, layer in enumerate(self.layers):
|
||||
x, c = layer(x, mask, cache=self.kv_cache[idx] if read_cache else None)
|
||||
if write_cache:
|
||||
if len(self.kv_cache) == 0:
|
||||
self.kv_cache = [None] * len(self.layers)
|
||||
self.kv_cache[idx] = c
|
||||
x = self.norm(x)
|
||||
if next_token_only:
|
||||
x = x[:, -1]
|
||||
return self.lm_head(x)
|
||||
|
||||
@classmethod
|
||||
def from_hugging_face(cls, model_path: str):
|
||||
config = LlamaConfig.from_pretrained(model_path)
|
||||
torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
|
||||
mx_weights = {k.replace("model.", ""):mx.array(v.numpy()) for k, v in torch_weights.items()}
|
||||
for k in mx_weights.keys():
|
||||
mx_weights[k] = mx_weights[k].astype(mx.float16)
|
||||
mlx_model = cls(config)
|
||||
mlx_model.update(tree_unflatten(list(mx_weights.items())))
|
||||
|
||||
return mlx_model
|
||||
|
||||
def generate(
|
||||
self,
|
||||
x: mx.array,
|
||||
temp=1.0,
|
||||
read_cache: bool = False
|
||||
):
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
|
||||
logit = self(x, read_cache=read_cache, write_cache=True, next_token_only=True)
|
||||
tok = mx.random.categorical(logit * (1 / temp))
|
||||
yield tok
|
||||
while True:
|
||||
x = tok.reshape(-1, 1)
|
||||
logit = self(x, read_cache=True, write_cache=True, next_token_only=True)
|
||||
tok = mx.random.categorical(logit * (1 / temp))
|
||||
yield tok
|
28
speculative_decoding/prompts.py
Normal file
28
speculative_decoding/prompts.py
Normal file
@ -0,0 +1,28 @@
|
||||
URIAL_1_SHOT = """# Instruction
|
||||
|
||||
Below is a list of conversations between a human and an AI assistant (you).
|
||||
Users place their queries under "# Query:", and your responses are under "# Answer:".
|
||||
You are a helpful, respectful, and honest assistant.
|
||||
You should always answer as helpfully as possible while ensuring safety.
|
||||
Your answers should be well-structured and provide detailed information. They should also have an engaging tone.
|
||||
Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful.
|
||||
Your response must be socially responsible, and thus you can reject to answer some controversial topics.
|
||||
|
||||
# Query:
|
||||
```Can you tell me some common types of renewable energy sources?```
|
||||
|
||||
# Answer:
|
||||
```
|
||||
Sure, here are some of the most common types of renewable energy sources:
|
||||
|
||||
1. Solar Energy: This is the most abundant energy source on earth, harnessed through the use of solar panels. These panels convert sunlight into electricity without any moving parts, noise, pollution, or damage to the environment.
|
||||
2. Wind Energy: Wind turbines convert the kinetic energy in the wind into mechanical power. This mechanical power can be used for specific tasks (such as pumping water) or converted into electricity to power homes, businesses, and schools.
|
||||
3. Hydropower: Generated by using electricity generators to capture the energy from falling or fast-moving water. This renewable source can come from various ways, including a large dam on a river, tidal or wave energy from the ocean, or using small scale turbines in streams.
|
||||
4. Geothermal Energy: This type of energy is generated from the heat deep within the Earth. This heat can be used directly for heating buildings or to generate electricity. It is continuously produced inside the Earth and is nearly as reliable as the tides.
|
||||
5. Biomass Energy: Biomass is organic material that comes from plants and animals, and it contains stored energy from the sun. This energy can be burned directly or converted into biofuel which can burn more efficiently.
|
||||
|
||||
Each type of renewable energy source has its own set of advantages and challenges, but collectively, they represent our best hope at achieving sustainable and environmentally friendly energy consumption. Please let me know if you have any other questions!
|
||||
```"""
|
||||
|
||||
def create_urial_prompt(message: str):
|
||||
return URIAL_1_SHOT + "\n\n# Query:\n```" + message + "```\n\n# Answer:\n```\n"
|
3
speculative_decoding/requirements.txt
Normal file
3
speculative_decoding/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
mlx>=0.0.5
|
||||
transformers
|
||||
numpy
|
25
speculative_decoding/test.py
Normal file
25
speculative_decoding/test.py
Normal file
@ -0,0 +1,25 @@
|
||||
import time
|
||||
from engine import LlamaEngine
|
||||
|
||||
# This will use the chat template from the primary model
|
||||
engine = LlamaEngine(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Finish the monologue: To be, or not to be..."}
|
||||
]
|
||||
|
||||
# Do 1 regular generation to get warmed up (the first one is slow)
|
||||
engine.generate(messages, num_tokens=1, temp=0.1)
|
||||
|
||||
# Time regular generation
|
||||
start = time.time()
|
||||
engine.generate(messages, num_tokens=125, temp=0.1)
|
||||
print(f"Regular generation took {time.time() - start} seconds")
|
||||
|
||||
# Time speculative decoding
|
||||
start = time.time()
|
||||
engine.speculative_decode(messages, num_tokens=125, temp=0.1, n=5)
|
||||
print(f"Speculative decoding took {time.time() - start} seconds")
|
Loading…
Reference in New Issue
Block a user