mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00

* Made mypy compatible * reformatted * Added more fixes * Added fixes to speculative-decoding * Fixes * fix circle * revert some stuff --------- Co-authored-by: Awni Hannun <awni@apple.com>
193 lines
6.0 KiB
Python
193 lines
6.0 KiB
Python
from typing import List
|
|
|
|
import mlx.core as mx
|
|
import transformers
|
|
from model import Model
|
|
|
|
|
|
class Tokenizer:
    """Wrap a ``transformers.AutoTokenizer`` with MLX-friendly helpers.

    Exposes the EOS id, a fixed decoder start id of 0, and encode/decode
    methods that produce MLX arrays and accept Python token lists.
    """

    def __init__(self, model_name: str):
        # Options forwarded verbatim to ``AutoTokenizer.from_pretrained``.
        hf_kwargs = {"legacy": False, "model_max_length": 512}
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name, **hf_kwargs
        )
        self._decoder_start_id = 0

    @property
    def eos_id(self) -> int:
        """End-of-sequence token id as reported by the wrapped tokenizer."""
        hf_tok = self._tokenizer
        return hf_tok.eos_token_id

    @property
    def decoder_start_id(self) -> int:
        """Initial decoder token id (always 0 for this wrapper)."""
        return self._decoder_start_id

    def encode(self, s: str) -> mx.array:
        """Tokenize ``s`` into a 1-D MLX array of input ids (no attention mask)."""
        batch = self._tokenizer(s, return_tensors="np", return_attention_mask=False)
        ids = batch["input_ids"]
        # The tokenizer returns a batch of one sequence; drop the batch axis.
        return mx.array(ids.squeeze(0))

    def decode(self, t: List[int]) -> str:
        """Turn a sequence of token ids back into text."""
        text = self._tokenizer.decode(t)
        return text
|
|
|
|
|
|
class SpeculativeDecoder:
    """Greedy and draft-model speculative decoding over a pair of ``Model``s.

    In ``speculative_decode`` a cheaper ``draft_model`` greedily proposes up to
    ``num_draft`` tokens per step. The main ``model`` scores all of them in a
    single decoder call and keeps the accepted prefix plus one token of its own.
    """

    def __init__(
        self,
        model: Model,
        draft_model: Model,
        tokenizer: str,
        num_draft: int = 5,
        delta: float = 0.0,
    ):
        """
        Args:
            model: Main (verifying) model.
            draft_model: Model that proposes draft tokens.
            tokenizer: Name passed to ``Tokenizer`` (and from there to
                ``transformers.AutoTokenizer.from_pretrained``).
            num_draft: Maximum number of tokens drafted per step.
            delta: Acceptance leniency. A draft token passes when
                ``log(max(u - delta, 0)) <= log p - log q`` with ``u`` uniform,
                so larger values accept more drafts.
        """
        self.tokenizer = Tokenizer(tokenizer)
        self.model = model
        self.draft_model = draft_model
        self.num_draft = num_draft
        self.delta = delta

    def _generate(
        self,
        x: mx.array,
        memory: mx.array,
        draft: bool = False,
    ):
        """Yield greedy ``(token, logprob)`` pairs forever.

        Each step decodes ``x`` against ``memory`` with the main model (or the
        draft model when ``draft`` is true) and takes the argmax of the last
        position's logits. That single token is fed back as the next ``x``.
        Only the newest token is passed each step, so earlier context is
        expected to live in the model's decoder cache.

        Yields:
            A 1-element token array and its log-probability under the model
            that produced it.
        """
        model = self.draft_model if draft else self.model
        while True:
            logits = model.decode(x[None], memory)[0, -1]
            x = mx.argmax(logits, keepdims=True)
            lognorm = mx.logsumexp(logits.astype(mx.float32))
            logprob = logits[x] - lognorm
            yield x, logprob

    def generate(
        self,
        prompt,
        max_tokens: int = 100,
    ):
        """Greedily decode ``prompt`` with the main model, streaming to stdout.

        Stops at EOS or after ``max_tokens`` tokens, then resets the main
        model's cache.
        """
        memory = self.model.encode(self.tokenizer.encode(prompt)[None])
        x = mx.array([self.tokenizer.decoder_start_id])
        skip = 0
        outputs = []
        # ``range`` goes first: zip evaluates left to right, so it stops
        # before asking the generator for an extra token (a wasted decode).
        for n, (token, _) in zip(range(max_tokens), self._generate(x, memory)):
            token_id = token.item()
            if token_id == self.tokenizer.eos_id:
                break
            outputs.append(token_id)
            if (n + 1) % 10 == 0:
                str_output = self.tokenizer.decode(outputs)
                print(str_output[skip:], end="", flush=True)
                skip = len(str_output)

        print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
        print()
        self.model.reset_cache()

    def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
        """Return how many leading draft tokens the main model accepts.

        Args:
            draft_tokens: 1-D array of the ``k`` drafted token ids.
            draft_probs: Draft-model log-probabilities of those tokens, ``(k,)``.
            model_logits: Main-model logits at the positions predicting each
                draft token (assumed ``(k, vocab)``).

        Token ``i`` passes when ``log(max(u_i - delta, 0)) <= log p_i - log q_i``
        for a uniform sample ``u_i``. The result is the length of the passing
        prefix.
        """
        # Main-model log-probability of each drafted token.
        model_logprobs = mx.take_along_axis(
            model_logits,
            draft_tokens[:, None],
            axis=-1,
        ).squeeze(-1)
        model_logprobs -= mx.logsumexp(model_logits.astype(mx.float32), axis=-1)
        unis = mx.random.uniform(shape=(draft_tokens.size,))
        # log(0) = -inf always passes, so u <= delta forces acceptance.
        log_unis = mx.log(mx.maximum(unis - self.delta, 0.0))
        accept_toks = log_unis <= (model_logprobs - draft_probs)
        # Index of the first rejection == length of the accepted prefix.
        return (accept_toks.tolist() + [False]).index(False)

    def speculative_decode(
        self,
        prompt,
        max_tokens: int = 100,
    ):
        """Decode ``prompt`` with draft-and-verify speculative decoding.

        Each step:
        - the draft model greedily proposes up to ``num_draft`` tokens;
        - the main model scores ``[last verified token] + drafts`` in one call;
        - the accepted prefix is kept, and the main model's argmax at the
          first rejected position (or after the last draft) is appended.

        Output is streamed to stdout, and both models' caches are reset at the
        end.

        Returns:
            dict with ``n_accepted`` (accepted draft tokens), ``n_draft``
            (drafted tokens) and ``n_steps`` (verification steps).
        """
        if max_tokens <= 0:
            # Nothing can be drafted; concatenating an empty draft would fail.
            print()
            return {"n_accepted": 0, "n_draft": 0, "n_steps": 0}

        eos_id = self.tokenizer.eos_id
        prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None]
        memory = self.model.encode(prompt)
        draft_memory = self.draft_model.encode(prompt)

        tokens = mx.array([self.tokenizer.decoder_start_id])

        n_steps = 0
        ntoks = 0
        n_accepted = 0
        n_draft = 0

        outputs = []
        skip = 0
        draft_inputs = tokens
        inputs = tokens
        while True:
            # Draft up to num_draft tokens (never past max_tokens). Stop early
            # if the draft model emits EOS. ``range`` comes first in zip so the
            # generator is not advanced (no extra decode) once it is exhausted.
            draft_tokens = []
            draft_probs = []
            for _, (t, p) in zip(
                range(ntoks, min(ntoks + self.num_draft, max_tokens)),
                self._generate(draft_inputs, draft_memory, draft=True),
            ):
                draft_tokens.append(t)
                draft_probs.append(p)
                if t.item() == eos_id:
                    break

            # Verify all drafts in one main-model call, prefixed by the last
            # verified token.
            draft_tokens = mx.concatenate(draft_tokens)
            draft_probs = mx.concatenate(draft_probs)
            verify_tokens = mx.concatenate([inputs, draft_tokens])
            logits = self.model.decode(
                verify_tokens[None, :],
                memory,
            ).squeeze(0)

            # Row i of ``logits`` predicts verify_tokens[i + 1]. The first k
            # rows score the k drafts; the last row follows the final draft.
            num_to_accept = self._get_num_accept(
                draft_tokens,
                draft_probs,
                logits[:-1],
            )
            # Accepted prefix plus the main model's own next token.
            new_tokens = mx.concatenate(
                [
                    draft_tokens[:num_to_accept],
                    mx.argmax(logits[num_to_accept], keepdims=True),
                ]
            )

            num_drafted = len(draft_tokens)
            n_accepted += num_to_accept
            n_draft += num_drafted
            n_steps += 1

            # Cache bookkeeping (assumes truncate_cache(n) drops the newest n
            # cached positions — TODO confirm against Model). With k drafts
            # t_1..t_k and a accepted:
            #   - the draft cache gained draft_inputs + t_1..t_{k-1}
            #     (t_k was yielded but never fed back);
            #   - the main cache gained inputs + t_1..t_k.
            if num_to_accept < num_drafted:
                # Keep t_1..t_a in both caches. The draft cache already holds
                # t_a, so feeding it again would duplicate it; send only the
                # main model's new token next.
                draft_trim = num_drafted - num_to_accept - 1
                if draft_trim > 0:
                    self.draft_model.truncate_cache(draft_trim)
                self.model.truncate_cache(num_drafted - num_to_accept)
                next_draft_inputs = new_tokens[-1:]
            else:
                # All drafts accepted. The draft model has not seen t_k yet,
                # so feed it along with the main model's new token.
                next_draft_inputs = new_tokens[-2:]

            hit_eos = False
            for t in new_tokens.tolist():
                if t == eos_id:
                    hit_eos = True
                    break
                if ntoks >= max_tokens:
                    break
                outputs.append(t)
                ntoks += 1

            str_output = self.tokenizer.decode(outputs)
            print(str_output[skip:], end="", flush=True)
            skip = len(str_output)

            # EOS can be an accepted draft token rather than the last entry of
            # new_tokens, so rely on the flag instead of new_tokens[-1].
            if hit_eos or ntoks >= max_tokens:
                break
            draft_inputs = next_draft_inputs
            inputs = new_tokens[-1:]

        print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
        print()

        self.model.reset_cache()
        self.draft_model.reset_cache()
        return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps}
|