clean up and readding temperature argument

This commit is contained in:
Goekdeniz-Guelmez 2025-02-22 02:34:56 +01:00
parent d653371e3d
commit d9c4c6e60c

View File

@ -131,7 +131,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
return scores
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>"):
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
if model.training == False:
print("Model is in training mode", model.training, "Manually setting to eval mode")
model.train()
@ -146,8 +146,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
end_sequence = mx.array(tokenizer.encode(end_token))
results = []
tokens_generated = 0
start_time = time.perf_counter()
try:
for idx in range(batch_size):
@ -160,20 +158,23 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
logits = model(current_input[None], cache=prompt_cache)[:, -1]
while len(current_tokens) < max_tokens:
probs = nn.softmax(logits, axis=-1)
logits_temp = logits / temperature
probs = nn.softmax(logits_temp, axis=-1)
next_token = mx.argmax(probs, axis=-1)
token = next_token.item()
if token == tokenizer.eos_token_id:
break
if (len(current_tokens) >= len(end_sequence) and
test_sequence = current_tokens + [token]
if (len(test_sequence) >= len(end_sequence) and
mx.array_equal(
mx.array(current_tokens[-len(end_sequence):]),
mx.array(test_sequence[-len(end_sequence):]),
end_sequence
)):
break
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
tokens_generated += 1
current_input = mx.array([token])
logits = model(current_input[None], cache=prompt_cache)[:, -1]
mx.eval(current_input)
@ -189,15 +190,12 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
tokens_generated += 1
if current_tokens:
results.append(mx.array(current_tokens))
mx.metal.clear_cache()
mx.eval(results)
generation_time = time.perf_counter() - start_time
print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
return results
except Exception as e:
@ -266,7 +264,8 @@ def grpo_loss(
max_tokens,
tokenizer,
group_size,
is_training=True
is_training=True,
temperature=temperature
)
if completions is not None:
for completion_ids in completions: