mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
format
This commit is contained in:
@@ -93,19 +93,28 @@ class PromptLookupDecoder:
|
||||
while True:
|
||||
# For each decoding step: generate n_draft tokens by searching the prompt
|
||||
def find_draft(input_ids):
|
||||
input_length = input_ids.size
|
||||
ngram = input_ids[-ngram_max:]
|
||||
|
||||
for ngram_size in range(ngram_max, ngram_min, -1):
|
||||
ngram = input_ids[-ngram_size:]
|
||||
largest_match = 0
|
||||
candidate = mx.array([], dtype=mx.uint32)
|
||||
for i in range(input_ids.size - (ngram_max * 2)):
|
||||
matches = input_ids[i : i + ngram_max] == ngram
|
||||
|
||||
for i in range(input_length - (ngram_size * 2)):
|
||||
if mx.all(input_ids[i : i + ngram_size] == ngram):
|
||||
start_idx = i + ngram_size
|
||||
end_idx = start_idx + n_draft
|
||||
if start_idx < input_length - ngram_size:
|
||||
return input_ids[start_idx:end_idx]
|
||||
# reverse through the matches array
|
||||
match_length = 0
|
||||
for j in range(matches.size - 1, -1, -1):
|
||||
if matches[j]:
|
||||
match_length += 1
|
||||
else:
|
||||
break
|
||||
|
||||
return mx.array([], dtype=mx.uint32)
|
||||
if match_length > ngram_min and match_length > largest_match:
|
||||
largest_match = match_length
|
||||
start_idx = i + ngram_max
|
||||
end_idx = start_idx + n_draft
|
||||
candidate = input_ids[start_idx:end_idx]
|
||||
|
||||
return candidate
|
||||
|
||||
draft_tokens = find_draft(tokens)
|
||||
n_drafted += draft_tokens.size
|
||||
|
Reference in New Issue
Block a user