Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 18:51:18 +08:00)

Commit bcfa55d882 ("updates"), parent 0a19522ec4
@@ -115,28 +115,24 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     if prompt.shape[1] == 0:
         return None
 
-    # Get "</answer>" token ids
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
 
-    # Use int32 for token ids
     output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
     output[:prompt.shape[1]] = prompt[0]
     current_length = prompt.shape[1]
 
     try:
+        def sample(logits):
+            if temperature > 0:
+                logits /= temperature
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
         for _ in range(max_tokens):
             current_input = output[:current_length][None, :]
             logits = model(current_input)
             token_logits = logits[0, -1]
-
-            if temperature > 0:
-                token_logits /= temperature
-
-            probs = mx.softmax(token_logits)
-            next_token = mx.random.categorical(probs[None, :]).astype(mx.int32)
-            next_token = next_token[0]
-
+            next_token = sample(token_logits)
             token_value = next_token.item()
             output[current_length] = token_value
             current_length += 1
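Note on the sampling change: the inline temperature/softmax/categorical code is folded into a small `sample` helper that works on log-probabilities. For reference, the helper can be exercised on its own; below is a minimal standalone sketch (the `sample_token` name and the toy logits are illustrative, not from the repo):

import mlx.core as mx

def sample_token(token_logits: mx.array, temperature: float) -> mx.array:
    # Temperature scaling: values above 1 flatten the distribution,
    # values between 0 and 1 sharpen it; temperature == 0 skips scaling here.
    if temperature > 0:
        token_logits = token_logits / temperature
    # Normalize to log-probabilities. mx.random.categorical accepts
    # unnormalized logits, so this subtraction does not change the sampled
    # distribution; it just makes the values proper log-probs.
    logprobs = token_logits - mx.logsumexp(token_logits, keepdims=True)
    # Draw one token id from the categorical distribution over the vocabulary.
    return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]

# Toy 8-token vocabulary, purely for illustration.
logits = mx.array([0.1, 2.0, -1.0, 0.5, 0.0, 1.5, -0.5, 0.3])
print(sample_token(logits, temperature=0.8).item())

Since categorical sampling is invariant to a constant shift of the logits, the logsumexp normalization mainly matters if the log-probabilities are needed elsewhere (for example for GRPO's policy-gradient terms); for sampling alone it is optional.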
@@ -146,16 +142,19 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
 
             if current_length >= end_sequence_length:
                 last_tokens = output[current_length - end_sequence_length:current_length].tolist()
+
+                # print(f"Last tokens: {last_tokens}")
+                # print(f"Decoded text: {tokenizer.decode(last_tokens)}")
+                # print(f"Target sequence: {end_sequence}")
                 if last_tokens == end_sequence:
                     break
 
         if current_length > prompt.shape[1]:
-            result = output[:current_length]
-            return result
+            return output[:current_length]
 
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
 
     return None
 
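Note on the stop condition: generation halts once the most recently produced token ids exactly match the ids of "</answer>". A minimal sketch of that check in isolation (the `hit_stop_sequence` name and the example ids are made up; real ids depend on the tokenizer):

def hit_stop_sequence(output_ids: list[int], end_sequence: list[int]) -> bool:
    # True once the tail of the generated ids equals the encoded stop sequence.
    n = len(end_sequence)
    return len(output_ids) >= n and output_ids[-n:] == end_sequence

# Suppose, hypothetically, that "</answer>" encodes to [27, 9399, 29].
end_sequence = [27, 9399, 29]
print(hit_stop_sequence([5, 11, 27, 9399, 29], end_sequence))  # True
print(hit_stop_sequence([5, 11, 27, 9399], end_sequence))      # False

One caveat with token-level matching is that a tokenizer can split the same text differently depending on what precedes it, so decoding the tail and comparing strings is sometimes the more robust check.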