fix mlx_lm generator for chinese (#321)

* fix generator for chinese

* add REPLACEMENT_CHAR

---------

Co-authored-by: cg <cg@qq.com>
This commit is contained in:
someone 2024-01-16 23:13:33 +08:00 committed by GitHub
parent b0870ed679
commit 2287294723
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -122,6 +122,8 @@ def generate(
tokens = [] tokens = []
skip = 0 skip = 0
REPLACEMENT_CHAR = '\ufffd'
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)): for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
@ -130,10 +132,11 @@ def generate(
if verbose: if verbose:
s = tokenizer.decode(tokens) s = tokenizer.decode(tokens)
if REPLACEMENT_CHAR not in s:
print(s[skip:], end="", flush=True) print(s[skip:], end="", flush=True)
skip = len(s) skip = len(s)
tokens = tokenizer.decode(tokens) tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
if verbose: if verbose:
print(tokens[skip:], flush=True) print(tokens[skip:], flush=True)
return tokens return tokens