This commit is contained in:
Goekdeniz-Guelmez 2025-02-22 01:05:58 +01:00
parent c51b0a2715
commit 79de353530

View File

@ -155,23 +155,20 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
current_input = expanded_prompts[idx] current_input = expanded_prompts[idx]
while len(current_tokens) < max_tokens: while len(current_tokens) < max_tokens:
logits = model(current_input[None])[:, -1] logits = model(current_input[None])[:, -1]
next_token = mx.argmax(logits, axis=-1) probs = nn.softmax(logits, axis=-1)
next_token = mx.argmax(probs, axis=-1)
token = next_token.item() token = next_token.item()
current_tokens.append(token) current_tokens.append(token)
tokens_generated += 1 tokens_generated += 1
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
if (len(current_tokens) >= len(end_sequence) and if (len(current_tokens) >= len(end_sequence) and
mx.array_equal( mx.array_equal(
mx.array(current_tokens[-len(end_sequence):]), mx.array(current_tokens[-len(end_sequence):]),
end_sequence end_sequence
)): )):
break break
current_input = mx.concatenate([current_input, mx.array([token])]) current_input = mx.concatenate([current_input, mx.array([token])])
if len(current_tokens) % 32 == 0: if len(current_tokens) % 32 == 0:
mx.eval(current_input) mx.eval(current_input)
mx.metal.clear_cache() mx.metal.clear_cache()
@ -182,7 +179,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
max_tokens=max_tokens, max_tokens=max_tokens,
sampler=lambda x: mx.argmax(x, axis=-1) sampler=lambda x: mx.argmax(x, axis=-1)
) )
for token, _ in generator: for token, _ in generator:
current_tokens.append(token) current_tokens.append(token)
tokens_generated += 1 tokens_generated += 1
@ -276,6 +272,8 @@ def grpo_loss(
print(f"Generation error: {e}") print(f"Generation error: {e}")
continue continue
mx.metal.clear_cache()
expanded_answers = [] expanded_answers = []
expanded_prompts = [] expanded_prompts = []
for i in range(batch_size): for i in range(batch_size):