mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix mlx_lm generator for Chinese (#321)
* fix generator for Chinese
* add REPLACEMENT_CHAR

---------

Co-authored-by: cg <cg@qq.com>
This commit is contained in:
parent
b0870ed679
commit
2287294723
@ -122,6 +122,8 @@ def generate(
|
|||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
skip = 0
|
skip = 0
|
||||||
|
REPLACEMENT_CHAR = '\ufffd'
|
||||||
|
|
||||||
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
|
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
|
||||||
if token == tokenizer.eos_token_id:
|
if token == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
@ -130,10 +132,11 @@ def generate(
|
|||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
s = tokenizer.decode(tokens)
|
s = tokenizer.decode(tokens)
|
||||||
print(s[skip:], end="", flush=True)
|
if REPLACEMENT_CHAR not in s:
|
||||||
skip = len(s)
|
print(s[skip:], end="", flush=True)
|
||||||
|
skip = len(s)
|
||||||
|
|
||||||
tokens = tokenizer.decode(tokens)
|
tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
|
||||||
if verbose:
|
if verbose:
|
||||||
print(tokens[skip:], flush=True)
|
print(tokens[skip:], flush=True)
|
||||||
return tokens
|
return tokens
|
||||||
|
Loading…
Reference in New Issue
Block a user