nits

This commit is contained in:
parent c51b0a2715
commit 79de353530
@@ -155,23 +155,20 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
         current_input = expanded_prompts[idx]
         while len(current_tokens) < max_tokens:
             logits = model(current_input[None])[:, -1]
-            next_token = mx.argmax(logits, axis=-1)
+            probs = nn.softmax(logits, axis=-1)
+            next_token = mx.argmax(probs, axis=-1)
             token = next_token.item()
             current_tokens.append(token)
             tokens_generated += 1

             if token == tokenizer.eos_token_id:
                 break

             if (len(current_tokens) >= len(end_sequence) and
                 mx.array_equal(
                     mx.array(current_tokens[-len(end_sequence):]),
                     end_sequence
                 )):
                 break

             current_input = mx.concatenate([current_input, mx.array([token])])

             if len(current_tokens) % 32 == 0:
                 mx.eval(current_input)
                 mx.metal.clear_cache()
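Note on the hunk above: the new code takes a softmax before the argmax, which selects the same token as mx.argmax(logits, axis=-1) would, since softmax is monotonic; the probs step only makes the greedy choice explicit in probability space. For reference, a self-contained sketch of the same decoding pattern (model, eos_token_id, and end_sequence are assumed stand-ins, not the repository's exact objects):

import mlx.core as mx
import mlx.nn as nn

def greedy_decode(model, prompt, max_tokens, eos_token_id, end_sequence):
    # Sketch of the loop in the hunk above. `model` maps a (1, T) token
    # array to logits; `end_sequence` is a list of token ids (assumed names).
    current_input = mx.array(prompt)
    current_tokens = []
    while len(current_tokens) < max_tokens:
        logits = model(current_input[None])[:, -1]
        # softmax is monotonic, so this picks the same token as
        # mx.argmax(logits, axis=-1); kept here to mirror the diff.
        probs = nn.softmax(logits, axis=-1)
        token = mx.argmax(probs, axis=-1).item()
        current_tokens.append(token)
        if token == eos_token_id:
            break
        # Stop once the generated tail matches the end marker.
        if (len(current_tokens) >= len(end_sequence)
                and current_tokens[-len(end_sequence):] == end_sequence):
            break
        current_input = mx.concatenate([current_input, mx.array([token])])
        # Evaluate periodically so the lazy graph and the Metal buffer
        # cache do not grow for the whole generation.
        if len(current_tokens) % 32 == 0:
            mx.eval(current_input)
            mx.metal.clear_cache()
    return current_tokens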
@@ -182,7 +179,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             max_tokens=max_tokens,
             sampler=lambda x: mx.argmax(x, axis=-1)
         )

         for token, _ in generator:
             current_tokens.append(token)
             tokens_generated += 1
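Note on the hunk above: generation here is driven through a generator that accepts a sampler callable, and lambda x: mx.argmax(x, axis=-1) makes decoding deterministic. A minimal sketch of that contract, assuming a hypothetical token_stream generator (not the repository's actual one) that yields (token, logits) pairs:

import mlx.core as mx

def token_stream(model, prompt, max_tokens, sampler):
    # Hypothetical generator matching the `for token, _ in generator` loop:
    # yields (token, logits) pairs and feeds each chosen token back in.
    y = mx.array(prompt)
    for _ in range(max_tokens):
        logits = model(y[None])[:, -1]
        next_token = sampler(logits)  # decoding strategy injected by caller
        yield next_token.item(), logits
        y = mx.concatenate([y, next_token])

# Greedy decoding, as in the diff:
greedy = lambda x: mx.argmax(x, axis=-1)
# A stochastic sampler drops in without touching the loop:
sample_t = lambda x: mx.random.categorical(x / 0.7)

Passing the strategy as a callable keeps the decoding loop agnostic to greedy versus stochastic sampling.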
@@ -276,6 +272,8 @@ def grpo_loss(
             print(f"Generation error: {e}")
             continue

+    mx.metal.clear_cache()
+
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
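Note on the hunk above: the added mx.metal.clear_cache() runs between the generation phase and the loss bookkeeping. MLX retains freed Metal buffers in a cache for fast reuse, so peak memory can stay high after a long generation; clearing the cache returns those buffers to the system at the cost of re-allocating them later. A small sketch of the pattern (generate_fn and loss_fn are hypothetical placeholders):

import mlx.core as mx

def two_phase_step(generate_fn, loss_fn):
    # `generate_fn` and `loss_fn` are hypothetical stand-ins for the
    # trainer's generation and loss phases.
    completions = generate_fn()
    mx.eval(completions)      # force the lazy computation before freeing
    mx.metal.clear_cache()    # return cached Metal buffers to the system
    return loss_fn(completions)

# The effect is visible via the cache counter:
# mx.metal.get_cache_memory() drops to 0 right after clear_cache().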