clean up and re-add temperature argument

Goekdeniz-Guelmez 2025-02-22 02:34:56 +01:00
parent d653371e3d
commit d9c4c6e60c


@@ -131,7 +131,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores

-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>"):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
     if model.training == False:
         print("Model is in training mode", model.training, "Manually setting to eval mode")
         model.train()
@@ -146,8 +146,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
-    tokens_generated = 0
-    start_time = time.perf_counter()

     try:
         for idx in range(batch_size):
@@ -160,20 +158,23 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             logits = model(current_input[None], cache=prompt_cache)[:, -1]

             while len(current_tokens) < max_tokens:
-                probs = nn.softmax(logits, axis=-1)
+                logits_temp = logits / temperature
+                probs = nn.softmax(logits_temp, axis=-1)
                 next_token = mx.argmax(probs, axis=-1)
                 token = next_token.item()

-                if token == tokenizer.eos_token_id:
-                    break
-                if (len(current_tokens) >= len(end_sequence) and
+                test_sequence = current_tokens + [token]
+                if (len(test_sequence) >= len(end_sequence) and
                     mx.array_equal(
-                        mx.array(current_tokens[-len(end_sequence):]),
+                        mx.array(test_sequence[-len(end_sequence):]),
                         end_sequence
                     )):
                     break
+
+                if token == tokenizer.eos_token_id:
+                    break
+
                 current_tokens.append(token)
-                tokens_generated += 1
                 current_input = mx.array([token])
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
                 mx.eval(current_input)
@@ -189,15 +190,12 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                 if token == tokenizer.eos_token_id:
                     break
                 current_tokens.append(token)
-                tokens_generated += 1

             if current_tokens:
                 results.append(mx.array(current_tokens))

         mx.metal.clear_cache()
         mx.eval(results)
-        generation_time = time.perf_counter() - start_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
         return results

     except Exception as e:
@@ -266,7 +264,8 @@ def grpo_loss(
        max_tokens,
        tokenizer,
        group_size,
-        is_training=True
+        is_training=True,
+        temperature=temperature
    )

    if completions is not None:
        for completion_ids in completions:
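
As a standalone illustration of the knob this commit re-adds: temperature scaling divides the raw logits by `temperature` before they are turned into a probability distribution (values below 1 sharpen the distribution, values above 1 flatten it). Note that the diff above still picks tokens with `mx.argmax`, which is unaffected by the scaling; the minimal sketch below instead draws from the scaled distribution with `mx.random.categorical` purely to show the usual effect of the parameter. It assumes MLX (`mlx.core`), and the helper name `sample_with_temperature` is hypothetical, not code from this repository.

import mlx.core as mx

def sample_with_temperature(logits: mx.array, temperature: float = 0.8) -> mx.array:
    # Hypothetical helper, not part of this commit: scale logits by 1/temperature,
    # then sample. Guard against division by zero for very small temperatures.
    scaled = logits / max(temperature, 1e-6)
    # mx.random.categorical samples token ids from the distribution implied by
    # the (unnormalized) logits along the last axis.
    return mx.random.categorical(scaled, axis=-1)

With temperature=1.0 this reduces to plain sampling from the model's distribution, and as temperature approaches 0 it approaches the greedy argmax behaviour used in the diff.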