mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-07-19 09:31:13 +08:00

nits

This commit is contained in:
parent c51b0a2715
commit 79de353530
@@ -155,23 +155,20 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
         current_input = expanded_prompts[idx]

         while len(current_tokens) < max_tokens:
             logits = model(current_input[None])[:, -1]
-            next_token = mx.argmax(logits, axis=-1)
+            probs = nn.softmax(logits, axis=-1)
+            next_token = mx.argmax(probs, axis=-1)
             token = next_token.item()
             current_tokens.append(token)
             tokens_generated += 1

             if token == tokenizer.eos_token_id:
                 break

             if (len(current_tokens) >= len(end_sequence) and
                 mx.array_equal(
                     mx.array(current_tokens[-len(end_sequence):]),
                     end_sequence
                 )):
                 break

             current_input = mx.concatenate([current_input, mx.array([token])])
-
-            if len(current_tokens) % 32 == 0:
-                mx.eval(current_input)
-                mx.metal.clear_cache()
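A note on the sampling lines above: softmax is strictly monotonic, so taking argmax over nn.softmax(logits) selects the same token as taking argmax over the raw logits; the normalization only matters if the probabilities themselves are used elsewhere. A minimal sketch of the equivalence (the logit values are made up for illustration):

import mlx.core as mx
import mlx.nn as nn

# Hypothetical logits over a 4-token vocabulary.
logits = mx.array([[1.5, -0.3, 2.7, 0.1]])

# Greedy pick straight from the logits.
greedy_from_logits = mx.argmax(logits, axis=-1)

# Greedy pick after normalizing to probabilities.
probs = nn.softmax(logits, axis=-1)
greedy_from_probs = mx.argmax(probs, axis=-1)

# softmax preserves ordering, so both select the same index.
assert greedy_from_logits.item() == greedy_from_probs.item()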
@@ -182,7 +179,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                 max_tokens=max_tokens,
                 sampler=lambda x: mx.argmax(x, axis=-1)
             )

             for token, _ in generator:
                 current_tokens.append(token)
                 tokens_generated += 1
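For context, this generator-based path drains (token, logprobs) pairs and applies the same stopping rules as the manual loop in the previous hunk. A rough sketch of such a consumer, with the end-marker test factored into a helper (ends_with and collect_tokens are hypothetical names, not part of the repo):

import mlx.core as mx

def ends_with(tokens, end_sequence):
    # True once the tail of `tokens` matches `end_sequence`.
    if len(tokens) < len(end_sequence):
        return False
    return mx.array_equal(mx.array(tokens[-len(end_sequence):]), end_sequence)

def collect_tokens(generator, tokenizer, end_sequence, max_tokens):
    # Drain a (token, logprobs) generator, stopping on EOS,
    # on the end marker, or at the token budget.
    current_tokens = []
    for token, _ in generator:
        current_tokens.append(token)
        if token == tokenizer.eos_token_id:
            break
        if ends_with(current_tokens, end_sequence):
            break
        if len(current_tokens) >= max_tokens:
            break
    return current_tokens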
@@ -276,6 +272,8 @@ def grpo_loss(
             print(f"Generation error: {e}")
             continue

+    mx.metal.clear_cache()
+
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
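mx.eval and mx.metal.clear_cache appear in this file as memory-pressure controls: forcing evaluation materializes MLX's lazy graph, and clearing the Metal buffer cache returns freed GPU memory. A small sketch of the pattern, assuming a long-running loop (run_steps, step_fn, and the flush interval are illustrative, not from the repo):

import mlx.core as mx

def run_steps(step_fn, num_steps, flush_every=32):
    # Periodically force evaluation and drop cached Metal buffers
    # so a long-running lazy-evaluation loop does not accumulate memory.
    state = None
    for i in range(num_steps):
        state = step_fn(state)
        if (i + 1) % flush_every == 0:
            mx.eval(state)          # materialize pending lazy ops
            mx.metal.clear_cache()  # release cached GPU buffers
    return state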