mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 12:51:12 +08:00
udpate
This commit is contained in:
parent
0a09a93454
commit
2a8e6f6e44
@ -125,6 +125,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
|
|||||||
if prompt.shape[1] == 0:
|
if prompt.shape[1] == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Get "</answer>" token ids
|
||||||
|
end_sequence = tokenizer.encode("</answer>")
|
||||||
|
end_sequence_length = len(end_sequence)
|
||||||
|
|
||||||
|
# Use int32 for token ids
|
||||||
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
|
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
|
||||||
output[:prompt.shape[1]] = prompt[0]
|
output[:prompt.shape[1]] = prompt[0]
|
||||||
current_length = prompt.shape[1]
|
current_length = prompt.shape[1]
|
||||||
@ -146,9 +151,16 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
|
|||||||
output[current_length] = token_value
|
output[current_length] = token_value
|
||||||
current_length += 1
|
current_length += 1
|
||||||
|
|
||||||
|
# Check for EOS token
|
||||||
if token_value == tokenizer.eos_token_id:
|
if token_value == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Check for "</answer>" sequence
|
||||||
|
if current_length >= end_sequence_length:
|
||||||
|
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
|
||||||
|
if last_tokens == end_sequence:
|
||||||
|
break
|
||||||
|
|
||||||
if current_length > prompt.shape[1]:
|
if current_length > prompt.shape[1]:
|
||||||
result = output[:current_length]
|
result = output[:current_length]
|
||||||
return result
|
return result
|
||||||
|
Loading…
Reference in New Issue
Block a user