fix wrong generation in train

commit 9705ed908e
parent d9c4c6e60c
Author: Goekdeniz-Guelmez
Date:   2025-02-22 17:21:08 +01:00


@@ -132,21 +132,15 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
 def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
-    if model.training == False:
-        print("Model is in training mode", model.training, "Manually setting to eval mode")
-        model.train()
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
-    mx.eval(expanded_prompts)
     try:
         for idx in range(batch_size):
             current_tokens = []
@@ -154,8 +148,8 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             if is_training:
                 current_input = expanded_prompts[idx]
                 prompt_cache = cache.make_prompt_cache(model)
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
+                mx.eval(logits, prompt_cache)
                 while len(current_tokens) < max_tokens:
                     logits_temp = logits / temperature
@@ -169,16 +163,16 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                         mx.array(test_sequence[-len(end_sequence):]),
                         end_sequence
                     )):
+                        current_tokens.append(token)
                         break
                     if token == tokenizer.eos_token_id:
                         break
                     current_tokens.append(token)
                     current_input = mx.array([token])
                     logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                    mx.eval(current_input)
+                    mx.eval(current_input, logits, probs, next_token)
+                mx.metal.clear_cache()
             else:
                 generator = generate_step(
                     expanded_prompts[idx],
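
For readers following the diff, below is a minimal, self-contained sketch of the cached, token-by-token sampling loop that the is_training branch implements, including the end-tag handling and the eager mx.eval calls this commit adjusts. It assumes an MLX model whose __call__ returns logits of shape (1, seq_len, vocab_size) and mlx_lm's cache.make_prompt_cache helper; the function name sample_completion and its exact signature are illustrative, not part of the repository.

import mlx.core as mx
from mlx_lm.models import cache


def sample_completion(model, prompt, tokenizer, max_tokens=256,
                      temperature=0.8, end_token="</answer>"):
    # Prefill: run the full prompt once and keep only the last-position logits;
    # the prompt cache stores the attention state for the decode steps below.
    end_sequence = tokenizer.encode(end_token)
    prompt_cache = cache.make_prompt_cache(model)
    logits = model(prompt[None], cache=prompt_cache)[:, -1]
    mx.eval(logits)

    tokens = []
    while len(tokens) < max_tokens:
        # mx.random.categorical samples from softmax(logits); probs is kept
        # only so it can be evaluated together with the other arrays below.
        probs = mx.softmax(logits / temperature, axis=-1)
        next_token = mx.random.categorical(logits / temperature)
        token = next_token.item()

        # Stop once the generated tail matches the end tag, keeping the tag's
        # final token in the output (the behaviour the extra append restores).
        candidate = tokens + [token]
        if len(candidate) >= len(end_sequence) and candidate[-len(end_sequence):] == end_sequence:
            tokens.append(token)
            break
        if token == tokenizer.eos_token_id:
            break

        tokens.append(token)
        # Decode step: feed only the new token; the cache supplies the history.
        logits = model(mx.array([token])[None], cache=prompt_cache)[:, -1]
        # Evaluate eagerly so the lazy graph does not grow across iterations.
        mx.eval(logits, probs, next_token)

    mx.metal.clear_cache()
    return tokens

Evaluating the per-step arrays inside the loop and clearing the Metal cache after each completion keeps memory bounded when the loop runs group_size times per prompt, which appears to be the situation the commit message refers to.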