Move code and apply a small optimization to find_draft

This commit is contained in:
Leon Ericsson 2023-12-29 12:18:09 +01:00
parent c73bec5598
commit 453ca97528
6 changed files with 3 additions and 15 deletions

View File

@ -53,18 +53,6 @@ class PromptLookupDecoder:
print()
print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===")
"""
Considerations:
- If a match is found but we can't draft n_draft tokens, do we draft as
many as we can or check for a match with a smaller ngram size?
- How do we choose if there are multiple matches?
This implementation:
- Ignores a match if we can't draft n_draft tokens. This avoids the risk
of only drafting a few tokens.
- Exits upon the first match. This avoids the need to rank matches.
"""
def prompt_lookup(
self,
prompt: str,
@ -104,13 +92,13 @@ class PromptLookupDecoder:
while True:
# For each decoding step: generate n_draft tokens by searching the prompt
def generate_draft(input_ids):
def find_draft(input_ids):
input_length = input_ids.size
for ngram_size in range(ngram_max, ngram_min, -1):
ngram = input_ids[-ngram_size:]
for i in range(input_length - ngram_size):
for i in range(input_length - (ngram_size * 2)):
if mx.all(input_ids[i : i + ngram_size] == ngram):
start_idx = i + ngram_size
end_idx = start_idx + n_draft
@ -119,7 +107,7 @@ class PromptLookupDecoder:
return mx.array([], dtype=mx.uint32)
draft_tokens = generate_draft(tokens)
draft_tokens = find_draft(tokens)
n_drafted += draft_tokens.size
# Verify draft tokens with the last verified token