This commit is contained in:
Leon Ericsson
2023-12-29 13:19:28 +01:00
parent 453ca97528
commit 3f750759d3

View File

@@ -93,19 +93,28 @@ class PromptLookupDecoder:
while True:
# For each decoding step: generate n_draft tokens by searching the prompt
def find_draft(input_ids):
    """Propose draft tokens by finding the trailing n-gram earlier in the input.

    The last ``ngram_max`` tokens of ``input_ids`` are compared against every
    earlier window of the same length. For each window, the number of
    positions that agree with the trailing n-gram is counted from the end of
    the window backwards, stopping at the first mismatch. The first window
    with the strictly longest such run, provided the run is longer than
    ``ngram_min``, supplies the draft: the tokens that immediately follow it.

    ``ngram_max``, ``ngram_min``, ``n_draft`` and ``mx`` are read from the
    enclosing scope.

    Args:
        input_ids: 1-D token array exposing ``.size`` and slicing
            (presumably an ``mx.array`` -- confirm against the caller).

    Returns:
        The slice of ``input_ids`` starting right after the best-matching
        window and spanning at most ``n_draft`` tokens, or an empty
        ``mx.uint32`` array when no window qualifies.
    """
    # NOTE(review): reconstructed from an unindented diff view that interleaved
    # this variant with an early-return variant; hunk ordering indicates this
    # is the post-change version -- confirm against the repository.
    ngram = input_ids[-ngram_max:]
    largest_match = 0
    candidate = mx.array([], dtype=mx.uint32)

    # Stop early enough that a window never overlaps the trailing n-gram
    # itself; a non-positive bound simply yields no iterations.
    for i in range(input_ids.size - (ngram_max * 2)):
        matches = input_ids[i : i + ngram_max] == ngram

        # Count how many trailing positions of the window agree with the
        # n-gram, walking backwards until the first mismatch.
        match_length = 0
        for j in range(matches.size - 1, -1, -1):
            if matches[j]:
                match_length += 1
            else:
                break

        # Strict comparison keeps the earliest window among equal-length runs.
        if match_length > ngram_min and match_length > largest_match:
            largest_match = match_length
            start_idx = i + ngram_max
            end_idx = start_idx + n_draft
            candidate = input_ids[start_idx:end_idx]

    return candidate
draft_tokens = find_draft(tokens)
n_drafted += draft_tokens.size