mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
move code && small opt to find_draft
This commit is contained in:
parent
c73bec5598
commit
453ca97528
@ -53,18 +53,6 @@ class PromptLookupDecoder:
|
||||
print()
|
||||
print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===")
|
||||
|
||||
"""
|
||||
Considerations:
|
||||
- If a match is found but we can't draft n_draft tokens, do we draft as
|
||||
many as we can or check for a match with a smaller ngram size?
|
||||
- How do we choose if there are multiple matches?
|
||||
|
||||
This implementation:
|
||||
- Ignores a match if we can't draft n_draft tokens. This avoids the risk
|
||||
of only drafting a few tokens.
|
||||
- We exit upon the first match. This avoids the need to rank matches.
|
||||
"""
|
||||
|
||||
def prompt_lookup(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -104,13 +92,13 @@ class PromptLookupDecoder:
|
||||
|
||||
while True:
|
||||
# For each decoding step: generate n_draft tokens by searching the prompt
|
||||
def generate_draft(input_ids):
|
||||
def find_draft(input_ids):
|
||||
input_length = input_ids.size
|
||||
|
||||
for ngram_size in range(ngram_max, ngram_min, -1):
|
||||
ngram = input_ids[-ngram_size:]
|
||||
|
||||
for i in range(input_length - ngram_size):
|
||||
for i in range(input_length - (ngram_size * 2)):
|
||||
if mx.all(input_ids[i : i + ngram_size] == ngram):
|
||||
start_idx = i + ngram_size
|
||||
end_idx = start_idx + n_draft
|
||||
@ -119,7 +107,7 @@ class PromptLookupDecoder:
|
||||
|
||||
return mx.array([], dtype=mx.uint32)
|
||||
|
||||
draft_tokens = generate_draft(tokens)
|
||||
draft_tokens = find_draft(tokens)
|
||||
n_drafted += draft_tokens.size
|
||||
|
||||
# Verify draft tokens with the last verified token
|
Loading…
Reference in New Issue
Block a user