commit c51b0a2715
parent 710bc1490e
Author: Goekdeniz-Guelmez
Date:   2025-02-22 00:21:47 +01:00


@@ -76,10 +76,30 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    has_think = r"<think>.*</think>"
-    has_answer = r"<answer>.*</answer>"
-    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
-    return [0.5 if match else 0.0 for match in matches]
+    scores = []
+    for completion in completions:
+        if not completion:
+            scores.append(0.0)
+            continue
+        reason_start = completion.find("<think>")
+        reason_end = completion.find("</think>")
+        answer_start = completion.find("<answer>")
+        answer_end = completion.find("</answer>")
+        if (reason_start != -1 and reason_end != -1 and
+            answer_start != -1 and answer_end != -1 and
+            reason_start < reason_end < answer_start < answer_end):
+            reason_content = completion[reason_start+13:reason_end].strip()
+            answer_content = completion[answer_start+8:answer_end].strip()
+            if reason_content and answer_content:
+                scores.append(0.5)
+                continue
+        scores.append(0.0)
+    return scores


 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
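For reference, the rewritten soft-format reward now requires the tags to appear in order, <think> ... </think> before <answer> ... </answer>, each with non-empty content, instead of two independent re.search checks. (With the 7-character <think> tag, the reason_start+13 slice also appears to drop the first few characters of the reasoning span, though it still detects non-empty content.) A minimal usage sketch; the import path is an assumption:

    from grpo_reward_functions import r1_soft_format_reward_func  # hypothetical module path

    good = "<think>work through the problem step by step</think><answer>42</answer>"
    bad = "<answer>42</answer><think>reasoning after the answer</think>"  # tags out of order

    # Expected: [0.5, 0.0] -- only the ordered, non-empty completion earns the format reward
    print(r1_soft_format_reward_func(prompts=["q1", "q2"], completions=[good, bad], answer=["42", "42"]))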
@@ -110,7 +130,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores


 def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
     if len(prompts.shape) == 1:
@@ -118,53 +138,49 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     if prompts.shape[1] == 0:
         return None
-    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    mx.eval(expanded_prompts)
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
-    for idx in range(batch_size):
-        current_prompt = expanded_prompts[idx:idx+1]
-        mx.eval(current_prompt)
-        current_tokens = []
-        try:
+    try:
+        for idx in range(batch_size):
+            current_tokens = []
             if is_training:
-                # Initialize with prompt
-                current_input = current_prompt[0]
-                mx.eval(current_input)
+                current_input = expanded_prompts[idx]
                 while len(current_tokens) < max_tokens:
-                    # Generate one token at a time
-                    logits = model(current_input[None])
-                    next_token = mx.random.categorical(logits[:, -1, :])
+                    logits = model(current_input[None])[:, -1]
+                    next_token = mx.argmax(logits, axis=-1)
                     token = next_token.item()
                     current_tokens.append(token)
                     tokens_generated += 1
-                    # Clear intermediate results
-                    mx.eval(next_token)
-                    del logits
                     if token == tokenizer.eos_token_id:
                         break
-                    # Update input for next iteration
-                    current_input = mx.array([token])
-                    mx.eval(current_input)
-                    # Clear cache periodically
-                    if len(current_tokens) % 8 == 0:
-                        mx.eval(current_input)
+                    if (len(current_tokens) >= len(end_sequence) and
+                        mx.array_equal(
+                            mx.array(current_tokens[-len(end_sequence):]),
+                            end_sequence
+                        )):
+                        break
+                    current_input = mx.concatenate([current_input, mx.array([token])])
+                    if len(current_tokens) % 32 == 0:
                         mx.metal.clear_cache()
             else:
                 generator = generate_step(
-                    current_prompt[0],
+                    expanded_prompts[idx],
                     model,
                     max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x)
+                    sampler=lambda x: mx.argmax(x, axis=-1)
                 )
                 for token, _ in generator:
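The training branch now stops as soon as the generated tail matches the encoded </answer> sequence, in addition to the EOS check. A minimal standalone version of that check (helper and tokenizer names are assumptions; mlx only):

    import mlx.core as mx

    def hits_stop_sequence(current_tokens, end_sequence):
        # True once the most recent token ids exactly match the encoded stop string, e.g. "</answer>"
        if len(current_tokens) < len(end_sequence):
            return False
        tail = mx.array(current_tokens[-len(end_sequence):])
        return bool(mx.array_equal(tail, end_sequence))

    # end_sequence = mx.array(tokenizer.encode("</answer>"))
    # if hits_stop_sequence(current_tokens, end_sequence): break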
@@ -174,28 +190,18 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                         break
             if current_tokens:
-                token_array = mx.array(current_tokens)
-                mx.eval(token_array)
-                results.append(token_array)
-                del token_array
-        except Exception as e:
-            print(f"Generation failed for sequence {idx}: {e}")
-            continue
-        mx.metal.clear_cache()
-    if not results:
-        print("No successful generations")
+                results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+        return results
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
         return None
-    mx.eval(results)
-    generation_time = time.perf_counter() - start_time
-    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-    return results


 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
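With this change, generation for the whole batch runs inside one try/except: generate_grpo returns either a list of token-id arrays (group_size greedily decoded rollouts per prompt) or None if anything fails. A minimal calling sketch; prompt_tensor and the concrete max_tokens / group_size values are illustrative assumptions:

    # prompt_tensor: mx.array of shape (num_prompts, prompt_len); model/tokenizer as elsewhere in this file
    completions = generate_grpo(model, prompt_tensor, max_tokens=256, tokenizer=tokenizer,
                                group_size=4, is_training=True)
    if completions is None:
        print("generation failed; skipping this batch")
    else:
        # one decoded string per rollout, num_prompts * group_size in total
        texts = [tokenizer.decode(ids.tolist()) for ids in completions]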
@@ -243,15 +249,23 @@ def grpo_loss(
         prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
         try:
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                True
-            )
+            if is_validation:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size
+                )
+            else:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size,
+                    is_training=True
+                )
             if completions is not None:
                 for completion_ids in completions:
                     completion_text = tokenizer.decode(completion_ids.tolist())
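The two branches differ only in the is_training flag: validation leaves it at its default of False, so generate_grpo falls through to the generate_step path, while training takes the token-by-token loop above. An equivalent collapsed form (a sketch, not what the diff does):

    completions = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, group_size,
                                is_training=not is_validation)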
@@ -261,8 +275,6 @@ def grpo_loss(
         except Exception as e:
             print(f"Generation error: {e}")
             continue
-        mx.metal.clear_cache()

     expanded_answers = []
     expanded_prompts = []