mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Partially stream de-tokenization (#609)
* partially stream de-tokenization * don't break full response
This commit is contained in:
parent
494cdf8e96
commit
5a52899405
@ -224,6 +224,7 @@ def generate(
|
|||||||
|
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
tokens = []
|
tokens = []
|
||||||
|
token_strings = []
|
||||||
skip = 0
|
skip = 0
|
||||||
REPLACEMENT_CHAR = "\ufffd"
|
REPLACEMENT_CHAR = "\ufffd"
|
||||||
|
|
||||||
@ -250,15 +251,20 @@ def generate(
|
|||||||
if formatter:
|
if formatter:
|
||||||
formatter(s[skip:], prob.item())
|
formatter(s[skip:], prob.item())
|
||||||
skip = len(s)
|
skip = len(s)
|
||||||
elif REPLACEMENT_CHAR not in s:
|
elif s[-1] != REPLACEMENT_CHAR:
|
||||||
print(s[skip:], end="", flush=True)
|
print(s[skip:], end="", flush=True)
|
||||||
skip = len(s)
|
skip = len(s)
|
||||||
|
# Reset token cache at line break
|
||||||
|
if s[-1] == "\n":
|
||||||
|
tokens = []
|
||||||
|
token_strings.append(s)
|
||||||
|
skip = 0
|
||||||
|
|
||||||
token_count = len(tokens)
|
token_count = n + 1
|
||||||
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(token_string[skip:], flush=True)
|
print(token_strings[-1][skip:], flush=True)
|
||||||
gen_time = time.perf_counter() - tic
|
gen_time = time.perf_counter() - tic
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
if token_count == 0:
|
if token_count == 0:
|
||||||
@ -269,7 +275,7 @@ def generate(
|
|||||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
|
|
||||||
return token_string
|
return "".join(token_strings)
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
|
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
|
||||||
|
Loading…
Reference in New Issue
Block a user