Partially stream de-tokenization (#609)

* partially stream de-tokenization

* don't break full response
Awni Hannun 2024-03-23 15:32:33 -07:00 committed by GitHub
parent 494cdf8e96
commit 5a52899405


@@ -224,6 +224,7 @@ def generate(
     tic = time.perf_counter()
     tokens = []
+    token_strings = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
@@ -250,15 +251,20 @@ def generate(
             if formatter:
                 formatter(s[skip:], prob.item())
                 skip = len(s)
-            elif REPLACEMENT_CHAR not in s:
+            elif s[-1] != REPLACEMENT_CHAR:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
+                # Reset token cache at line break
+                if s[-1] == "\n":
+                    tokens = []
+                    token_strings.append(s)
+                    skip = 0

-    token_count = len(tokens)
-    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
+    token_count = n + 1
+    token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))

     if verbose:
-        print(token_string[skip:], flush=True)
+        print(token_strings[-1][skip:], flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
         if token_count == 0:
@@ -269,7 +275,7 @@ def generate(
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")

-    return token_string
+    return "".join(token_strings)


 def load_model(model_path: Path, lazy: bool = False) -> nn.Module: