mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
spec decode gives same results as regular decode
This commit is contained in:
@@ -1,16 +1,32 @@
|
||||
## Speculative Decoding with MLX
|
||||
# Speculative Decoding
|
||||
|
||||
This example implements [speculative decoding](https://arxiv.org/abs/2211.17192), which allows you to use a smaller draft model to predict several tokens, and then a larger verification model to check them all in parallel. The results are output that is identical to what the larger model would produce, but with far fewer forward passes (as long as the reference model is good enough at guessing).
|
||||
This example implements [speculative decoding] for text generation.[^1].
|
||||
Speculative decoding uses a smaller draft model to propose several tokens, and
|
||||
then a larger model which decides which tokens to accept. The generated text is
|
||||
identical to what the larger model would produce on its own, but with far fewer
|
||||
forward passes of the large model since it can evaluate the draft tokens in
|
||||
parallel.
|
||||
|
||||
### Setup
|
||||
|
||||
First, install the requirements:
|
||||
|
||||
Install the requirements and then you can try it out:
|
||||
```
|
||||
cd speculative_decoding
|
||||
pip install -r requirements.txt
|
||||
python test.py
|
||||
```
|
||||
|
||||
In order for that to happen, it's generally good if the models are trained on similar data, with a similar chat template, etc. For example, you could use Meta's 7B Llama as a draft model for the 13B Llama. In my tests, I've mostly used TinyLlama as a draft model for Llama-7B. The chat versions of TinyLlama and Llama-7B-Chat are trained with different templates, but it works OK. Alternatively, you can use base models, and a prompt to make the model act like a chat model (e.g. [URIAL](https://arxiv.org/abs/2312.01552)).
|
||||
### Run
|
||||
|
||||
I believe the implementation is *correct* (it produces near-identical output with regular generation vs. speculative decoding, and when speculative decoding is enabled, the draft model does correctly predict many tokens). However, it assumes a batch size of 1 at the moment (I'm not actually sure how to handle batching where some drafts might have more correct tokens than others). Also I feel like it could be faster!
|
||||
You can run with the default arguments:
|
||||
|
||||
Before merging this in, I would appreciate some help understanding how to make this faster and optimizing the performance so it's actually useful!
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
Speculative decoding works well when most of the tokens from the draft model
|
||||
are accepted by the larger model. That's more likely to happen if the models
|
||||
are trained on similar data. The default setting in this example uses TinyLlama
|
||||
as a draft morel for Llama 7B.
|
||||
|
||||
[^1] See the paper [Fast Inference from Transformers via Speculative Decoding]((https://arxiv.org/abs/2211.17192)
|
||||
|
160
speculative_decoding/decoder.py
Normal file
160
speculative_decoding/decoder.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import transformers
|
||||
from dataclasses import dataclass, field
|
||||
from model import Llama
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import List, Optional
|
||||
|
||||
from prompts import create_urial_prompt
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_name: str):
|
||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.eos_token_id
|
||||
|
||||
def encode(self, s: str) -> mx.array:
|
||||
return mx.array(
|
||||
self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
|
||||
"input_ids"
|
||||
].squeeze(0)
|
||||
)
|
||||
|
||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||
|
||||
|
||||
class SpeculativeDecoder:
|
||||
def __init__(self, model: str, draft_model: str = None):
|
||||
self.tokenizer = Tokenizer(model)
|
||||
self.model = Llama.from_hugging_face(model)
|
||||
if draft_model is not None:
|
||||
self.draft_model = Llama.from_hugging_face(draft_model)
|
||||
|
||||
def tokenize(self, prompt):
|
||||
# if self.tokenizer.chat_template is not None:
|
||||
# tokenized = self.tokenizer.apply_chat_template(
|
||||
# prompt, tokenize=True, add_generation_prompt=True
|
||||
# )
|
||||
# else:
|
||||
# use urial zero-shot template
|
||||
tokenized = self.tokenizer.encode(create_urial_prompt(prompt["content"]))
|
||||
return tokenized
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
x: mx.array,
|
||||
temp: float = 0.0,
|
||||
draft: bool = False,
|
||||
):
|
||||
model = self.draft_model if draft else self.model
|
||||
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
while True:
|
||||
logit = model(x[None, :], next_token_only=True)
|
||||
x = sample(logit)
|
||||
yield x
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
max_tokens: int = 100,
|
||||
temp: float = 0.0,
|
||||
draft: bool = False,
|
||||
):
|
||||
x = self.tokenize(prompt)
|
||||
start = time.time()
|
||||
for token, n in zip(self._generate(x, temp, draft=draft), range(max_tokens)):
|
||||
token = token.item()
|
||||
if token == self.tokenizer.eos_id:
|
||||
break
|
||||
print(self.tokenizer.decode(token, with_sep=n > 0), end="", flush=True)
|
||||
run_time = time.time() - start
|
||||
print()
|
||||
print(f"=== GENERATED {n + 1} TOKENS in {run_time} SECONDS ===")
|
||||
if draft:
|
||||
self.draft_model.reset_cache()
|
||||
else:
|
||||
self.model.reset_cache()
|
||||
|
||||
def speculative_decode(
|
||||
self, prompt, max_tokens: int = 100, temp: float = 0.0, n_draft: int = 5
|
||||
):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
tokens = mx.array(self.tokenize(prompt), mx.uint32)
|
||||
start = time.time()
|
||||
|
||||
decoding_steps = 0
|
||||
ntoks = 0
|
||||
accepted_draft_tokens = 0
|
||||
|
||||
while True:
|
||||
# For each decoding step: generate n tokens from a draft model
|
||||
draft_tokens = []
|
||||
for _, t in zip(
|
||||
range(ntoks, min(ntoks + n_draft, max_tokens)),
|
||||
self._generate(tokens, temp=temp, draft=True),
|
||||
):
|
||||
draft_tokens.append(t)
|
||||
if t.item() == self.tokenizer.eos_id:
|
||||
break
|
||||
|
||||
# Verify the draft tokens with the last verified token
|
||||
draft_tokens = mx.concatenate(draft_tokens)
|
||||
verify_tokens = mx.concatenate([tokens, draft_tokens])
|
||||
logits = self.model(verify_tokens[None, :-1])
|
||||
sampled = sample(logits[:, -draft_tokens.size :]).squeeze(0)
|
||||
|
||||
# Only keep samples that match the draft:
|
||||
equal_toks = sampled == draft_tokens
|
||||
num_to_accept = (equal_toks.tolist() + [False]).index(False)
|
||||
new_tokens = sampled[: max(1, num_to_accept)]
|
||||
|
||||
accepted_draft_tokens += num_to_accept
|
||||
|
||||
# Rewind the cache for unaccepted tokens:
|
||||
if (n := draft_tokens.size) > num_to_accept:
|
||||
self.draft_model.truncate_cache(n - new_tokens.size)
|
||||
self.model.truncate_cache(n - new_tokens.size)
|
||||
|
||||
decoding_steps += 1
|
||||
|
||||
# Check stop decodig criteria:
|
||||
for t in new_tokens.tolist():
|
||||
if t == self.tokenizer.eos_id:
|
||||
break
|
||||
print(self.tokenizer.decode(t, with_sep=ntoks > 0), end="", flush=True)
|
||||
ntoks += new_tokens.size
|
||||
if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
||||
break
|
||||
tokens = new_tokens[-1:]
|
||||
|
||||
end = time.time()
|
||||
self.model.reset_cache()
|
||||
self.draft_model.reset_cache()
|
||||
print()
|
||||
print(
|
||||
"=== GENERATED",
|
||||
ntoks,
|
||||
"TOKENS IN",
|
||||
round(end - start, 2),
|
||||
"SECONDS ===",
|
||||
)
|
||||
print("=== ACCEPTED", accepted_draft_tokens, "DRAFT TOKENS ===")
|
||||
print("=== DECODING STEPS", decoding_steps, "===")
|
@@ -1,208 +0,0 @@
|
||||
import transformers
|
||||
from dataclasses import dataclass, field
|
||||
from model import Llama
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
from prompts import create_urial_prompt
|
||||
|
||||
|
||||
def clone(x: mx.array):
|
||||
return mx.array(np.array(x))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlamaEngine:
|
||||
model: str # path to HuggingFace repo
|
||||
draft_model: Optional[str] = None # path to draft model
|
||||
tokenizer: transformers.AutoTokenizer = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model)
|
||||
self.model = Llama.from_hugging_face(self.model)
|
||||
if self.draft_model is not None:
|
||||
self.draft_model = Llama.from_hugging_face(self.draft_model)
|
||||
|
||||
def tokenize(self, messages):
|
||||
if self.tokenizer.chat_template is not None:
|
||||
tokenized = self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
# use urial zero-shot template
|
||||
tokenized = self.tokenizer.encode(
|
||||
create_urial_prompt(messages[0]["content"])
|
||||
)
|
||||
|
||||
return tokenized
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages,
|
||||
num_tokens: int = 100,
|
||||
temp: float = 0.8,
|
||||
draft_model: bool = False, # if true gen with draft model
|
||||
):
|
||||
tokenized = self.tokenize(messages)
|
||||
x = mx.array([tokenized])
|
||||
tokens = []
|
||||
start = time.time()
|
||||
for token in self.model.generate(x, temp):
|
||||
if token.item() == 2:
|
||||
break
|
||||
tokens.append(token)
|
||||
if len(tokens) >= num_tokens:
|
||||
break
|
||||
run_time = time.time() - start
|
||||
tokens = [t.item() for t in tokens]
|
||||
s = self.tokenizer.decode(tokens)
|
||||
# print("=== COMPLETION ===")
|
||||
# print(s)
|
||||
print(f"=== GENERATED {len(tokens)} TOKENS in {run_time} SECONDS ===")
|
||||
return s
|
||||
|
||||
# generate only supports batch size 1, so should this
|
||||
def speculative_decode(
|
||||
self, messages, num_tokens: int = 100, temp: float = 0.8, n: int = 5
|
||||
):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
batch_size = 1
|
||||
tokenized = self.tokenize(messages)
|
||||
tokens = mx.array([tokenized])
|
||||
prompt_len = tokens.shape[1]
|
||||
start = time.time()
|
||||
|
||||
# prefill the main model
|
||||
# sample first token & write draft from there (avoids rewinding main model which i was doing before)
|
||||
logit = self.model(
|
||||
tokens, read_cache=False, write_cache=True, next_token_only=True
|
||||
)
|
||||
first_token = sample(logit).reshape(batch_size, 1)
|
||||
tokens = mx.concatenate([tokens, first_token], axis=1)
|
||||
|
||||
decoding_steps = 1
|
||||
n_new_tokens = 1
|
||||
accepted_draft_tokens = 0
|
||||
draft_logit = self.draft_model(
|
||||
tokens, read_cache=False, write_cache=True, next_token_only=True
|
||||
)
|
||||
# print("Before doing any speculative decoding, draft model cache is: ", self.draft_model.kv_cache[0][0].shape[2])
|
||||
# print("And prompt length is: ", prompt_len)
|
||||
while True:
|
||||
# for each decoding step: generate n tokens from a draft model
|
||||
draft_tokens = sample(draft_logit).reshape(batch_size, 1)
|
||||
draft_tokens_left = n - 1
|
||||
for (
|
||||
token
|
||||
) in self.draft_model.generate( # generate automatically updates the cache, it has to
|
||||
draft_tokens, temp=temp, read_cache=True
|
||||
):
|
||||
draft_tokens = mx.concatenate(
|
||||
[draft_tokens, token.reshape(batch_size, 1)], axis=1
|
||||
)
|
||||
draft_tokens_left -= 1
|
||||
if draft_tokens_left == 0:
|
||||
break
|
||||
|
||||
# have to verify the first draft token using the last verified token
|
||||
verify_tokens = mx.concatenate([tokens[:, -1:], draft_tokens], axis=1)
|
||||
# print("Tokens so far: ", self.tokenizer.decode(np.array(tokens[0, prompt_len:]).tolist()))
|
||||
# print("Predicted draft tokens: [", self.tokenizer.decode(np.array(draft_tokens[0, :]).tolist()), "]")
|
||||
logits = self.model(verify_tokens, read_cache=True, write_cache=True)
|
||||
# check the last n + 1 tokens
|
||||
sampled = sample(logits[:, -(n + 1) :, :])
|
||||
# print("Sampled tokens: [", self.tokenizer.decode(np.array(sampled[0, :]).tolist()), "]")
|
||||
|
||||
# only keep samples that match the draft
|
||||
num_to_accept = 0
|
||||
for i in range(n):
|
||||
if mx.all(sampled[:, i] == draft_tokens[:, i]):
|
||||
num_to_accept += 1
|
||||
else:
|
||||
break
|
||||
# print("Accepting", num_to_accept)
|
||||
accepted_draft_tokens += num_to_accept
|
||||
n_new_tokens += 1 + num_to_accept
|
||||
|
||||
new_tokens = sampled[:, : num_to_accept + 1]
|
||||
tokens = mx.concatenate([tokens, new_tokens], axis=1)
|
||||
|
||||
# truncate draft cache to keep only accepted tokens
|
||||
# what tokens have been INPUT into the draft model? let's say n = 5, start with |p| tokens
|
||||
# |p| -> t0; |p + t0| -> t1; |p + t0 + t1| -> t2; |p + t0 + t1 + t2| -> t3; |p + t0 + t1 + t2 + t3| -> t4;
|
||||
# return -> t0 - t4, cache has |p + t0 + t1 + t2 + t3|
|
||||
# main model accepts whatever is correct, then generates t'
|
||||
# if 0 accepted: cache should have |p + t'|
|
||||
# if 1 accepted: |p + t0 + t'|
|
||||
# if 2 accepted: |p + t0 + t1 + t'|
|
||||
# ...
|
||||
# if 5 accepted: |p + t0 + t1 + t2 + t3 + t4 + t'|
|
||||
# we're always going to have to show the draft model the 1 token where it went off
|
||||
# the rails and we rejected it and took the real model, cause that won't be in its cache
|
||||
# print("After speculative decoding, before truncation, draft cache has: ", self.draft_model.kv_cache[0][0].shape[2])
|
||||
if num_to_accept < n:
|
||||
self.draft_model.truncate_kv_cache(n - 1 - num_to_accept)
|
||||
# print("Truncated draft cache by", n - 1 - num_to_accept, "now it has", self.draft_model.kv_cache[0][0].shape[2])
|
||||
elif num_to_accept == n:
|
||||
# forward the model on the last draft token to catch it up
|
||||
# maybe this is bad?
|
||||
self.draft_model(
|
||||
draft_tokens[:, -1:],
|
||||
read_cache=True,
|
||||
write_cache=True,
|
||||
next_token_only=True,
|
||||
)
|
||||
|
||||
# now how to truncate the full model's cache?
|
||||
# it has |p + t0 + t1 + t2 + t3 + t4|
|
||||
# if 0 accepted: truncate back to p
|
||||
# if 1 accepted: truncate to p + t0
|
||||
self.model.truncate_kv_cache(
|
||||
n - num_to_accept
|
||||
) # how many to truncate? i think 1 off from draft model? idk
|
||||
# NOTE: main model doesn't know that it predicted t' (the 1 non-draft token)
|
||||
# i think this is ok because it's the last accepted token and will be passed back in at verification time
|
||||
|
||||
# NOTE: model is now (or could be!) 1 token ahead of draft model cause if it accepts the full
|
||||
# draft it's now predicted 1 token past the draft token's last token. must account for this.
|
||||
|
||||
# print("Length of accepted tokens: ", tokens.shape[1])
|
||||
# print("Length of draft model cache: ", self.draft_model.kv_cache[0][0].shape[2])
|
||||
# print("Length of main model cache: ", self.model.kv_cache[0][0].shape[2])
|
||||
decoding_steps += 1
|
||||
|
||||
if n_new_tokens >= num_tokens or mx.any(new_tokens == 2):
|
||||
break
|
||||
|
||||
# get the next draft token based on t', preparing to do it all again!
|
||||
# print("Getting the token that comes after: ", self.tokenizer.decode(np.array(tokens[0, -1:]).tolist()))
|
||||
draft_logit = self.draft_model(
|
||||
tokens[:, -1:], read_cache=True, write_cache=True, next_token_only=True
|
||||
)
|
||||
|
||||
mx.eval(tokens)
|
||||
end = time.time()
|
||||
|
||||
seq = np.array(tokens[0, :]).tolist()
|
||||
s = self.tokenizer.decode(seq[prompt_len:])
|
||||
# print(f"=== COMPLETION {0 + 1} ===")
|
||||
# print(s)
|
||||
|
||||
print(
|
||||
"=== GENERATED",
|
||||
n_new_tokens,
|
||||
"TOKENS IN",
|
||||
round(end - start, 2),
|
||||
"SECONDS ===",
|
||||
)
|
||||
print("=== ACCEPTED", accepted_draft_tokens, "DRAFT TOKENS ===")
|
||||
print("=== DECODING STEPS", decoding_steps, "===")
|
||||
return s
|
21
speculative_decoding/main.py
Normal file
21
speculative_decoding/main.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import time
|
||||
from decoder import SpeculativeDecoder
|
||||
|
||||
# This will use the chat template from the primary model
|
||||
engine = SpeculativeDecoder(
|
||||
# model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
|
||||
)
|
||||
|
||||
messages = {"role": "user", "content": "Finish the monologue: To be, or not to be..."}
|
||||
|
||||
# Do 1 regular generation to get warmed up (the first one is slow)
|
||||
engine.generate(messages, max_tokens=1)
|
||||
engine.generate(messages, max_tokens=1, draft=True)
|
||||
|
||||
# Time regular generation
|
||||
engine.generate(messages, max_tokens=125)
|
||||
|
||||
# Time speculative decoding
|
||||
engine.speculative_decode(messages, max_tokens=125, n_draft=10)
|
@@ -7,6 +7,14 @@ import mlx.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.float32):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
mask = mask.astype(dtype) * -1e9
|
||||
return mask
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
@@ -29,7 +37,6 @@ class Attention(nn.Module):
|
||||
self.n_heads: int = config.num_attention_heads
|
||||
self.n_kv_heads: int = config.num_key_value_heads
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
# print("heads", self.n_heads, "kv heads", self.n_kv_heads, "repeats", self.repeats)
|
||||
self.head_dim = config.hidden_size // self.n_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
@@ -63,12 +70,8 @@ class Attention(nn.Module):
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
kv_size = a.shape[-1]
|
||||
# can't use the L from x here, this is like cross-attention during decoding
|
||||
return a.reshape([B, self.n_heads, -1, kv_size])
|
||||
|
||||
# cache should be with unrepeated kv, otherwise GQA is pointless lol
|
||||
# keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@@ -79,17 +82,9 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# print("queries shape", queries.shape, "keys shape", keys.shape, "values shape", values.shape)
|
||||
|
||||
scores = (queries * self.scale) @ repeat(keys).transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
# print("we need to add mask of shape", mask.shape, "to scores of shape", scores.shape)
|
||||
if cache is None:
|
||||
scores += mask
|
||||
else:
|
||||
# we're doing "cross-attn"; add mask to the "end" of the attn matrix along the K dimension
|
||||
a, b = mx.split(scores, indices_or_sections=[-mask.shape[-1]], axis=-1)
|
||||
scores = mx.concatenate([a, b + mask], axis=-1)
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ repeat(values)).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
@@ -148,75 +143,55 @@ class Llama(nn.Module):
|
||||
]
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.kv_cache = []
|
||||
self.reset_cache()
|
||||
|
||||
def truncate_kv_cache(self, num_to_truncate):
|
||||
def truncate_cache(self, num_to_truncate):
|
||||
cache_length = self.kv_cache[0][0].shape[2]
|
||||
num_to_truncate = min(num_to_truncate, cache_length)
|
||||
if num_to_truncate == 0:
|
||||
return False
|
||||
else:
|
||||
if num_to_truncate < cache_length:
|
||||
self.kv_cache = tree_map(
|
||||
lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self.reset_cache()
|
||||
|
||||
def reset_cache(self):
|
||||
self.kv_cache = [None] * len(self.layers)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
read_cache: bool = False,
|
||||
write_cache: bool = False,
|
||||
next_token_only: bool = False,
|
||||
):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
if read_cache and len(self.kv_cache) != len(self.layers):
|
||||
raise RuntimeError(
|
||||
f"Length of cache ({len(self.kv_cache)}) must match number of layers ({len(self.layers)})"
|
||||
)
|
||||
if self.kv_cache[0]:
|
||||
offset = self.kv_cache[0][0].shape[-2]
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
if x.shape[1] > 1:
|
||||
mask = create_additive_causal_mask(x.shape[1], offset)
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
x = self.embed_tokens(x)
|
||||
for idx, layer in enumerate(self.layers):
|
||||
x, c = layer(x, mask, cache=self.kv_cache[idx] if read_cache else None)
|
||||
if write_cache:
|
||||
if len(self.kv_cache) == 0:
|
||||
self.kv_cache = [None] * len(self.layers)
|
||||
self.kv_cache[idx] = c
|
||||
x = self.norm(x)
|
||||
x, self.kv_cache[idx] = layer(x, mask, cache=self.kv_cache[idx])
|
||||
|
||||
if next_token_only:
|
||||
x = x[:, -1]
|
||||
|
||||
x = self.norm(x)
|
||||
return self.lm_head(x)
|
||||
|
||||
@classmethod
|
||||
def from_hugging_face(cls, model_path: str):
|
||||
config = LlamaConfig.from_pretrained(model_path)
|
||||
torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
|
||||
mx_weights = {
|
||||
k.replace("model.", ""): mx.array(v.numpy())
|
||||
weights = {
|
||||
k.replace("model.", ""): mx.array(v.numpy(), mx.float16)
|
||||
for k, v in torch_weights.items()
|
||||
}
|
||||
for k in mx_weights.keys():
|
||||
mx_weights[k] = mx_weights[k].astype(mx.float16)
|
||||
mlx_model = cls(config)
|
||||
mlx_model.update(tree_unflatten(list(mx_weights.items())))
|
||||
|
||||
return mlx_model
|
||||
|
||||
def generate(self, x: mx.array, temp=0.0, read_cache: bool = False):
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
|
||||
logit = self(x, read_cache=read_cache, write_cache=True, next_token_only=True)
|
||||
tok = sample(logit)
|
||||
yield tok
|
||||
while True:
|
||||
x = tok.reshape(-1, 1)
|
||||
logit = self(x, read_cache=True, write_cache=True, next_token_only=True)
|
||||
tok = sample(logit)
|
||||
yield tok
|
||||
model = cls(config)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
@@ -1,25 +0,0 @@
|
||||
import time
|
||||
from engine import LlamaEngine
|
||||
|
||||
# This will use the chat template from the primary model
|
||||
engine = LlamaEngine(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Finish the monologue: To be, or not to be..."}
|
||||
]
|
||||
|
||||
# Do 1 regular generation to get warmed up (the first one is slow)
|
||||
engine.generate(messages, num_tokens=1, temp=0.1)
|
||||
|
||||
# Time regular generation
|
||||
start = time.time()
|
||||
engine.generate(messages, num_tokens=125, temp=0.1)
|
||||
print(f"Regular generation took {time.time() - start} seconds")
|
||||
|
||||
# Time speculative decoding
|
||||
start = time.time()
|
||||
engine.speculative_decode(messages, num_tokens=125, temp=0.1, n=5)
|
||||
print(f"Speculative decoding took {time.time() - start} seconds")
|
Reference in New Issue
Block a user