adding custom system message integration in the dataset, more optimizations (generation is now faster at the same RAM usage), fix for the identical generations, separated the reward functions into a separate file.

This commit is contained in:
Goekdeniz-Guelmez 2025-02-24 20:49:11 +01:00
parent bd5f081ca5
commit e4eac9c97b
3 changed files with 122 additions and 109 deletions
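What the dataset change does: when a training record carries a system field (key name defaults to "system", per the new system_key parameter), that record's system message replaces the built-in R1-style <think>/<answer> prompt. A hypothetical sketch of what such records could look like — the JSONL layout is an assumption; only the prompt/answer/system key names come from the diff:

import json

# Hypothetical training records for GRPODataset (file layout is an
# assumption; key names match the prompt_key/answer_key/system_key
# defaults shown in the diff below).
records = [
    # No "system" key: the dataset falls back to the default
    # <think>/<answer> reasoning system prompt.
    {"prompt": "What is 6 * 7?", "answer": "42"},
    # With "system": this record's system message overrides the default.
    {
        "prompt": "What is 6 * 7?",
        "answer": "42",
        "system": "You are a concise math tutor. Reason inside <think> tags and answer inside <answer> tags.",
    },
]

with open("train.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")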

View File

@@ -19,6 +19,7 @@ class GRPODataset:
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
answer_key: str = "answer",
system_key: str = "system",
use_chat_template: bool = False,
use_prompt: bool = False
):
@@ -27,9 +28,11 @@ class GRPODataset:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
if use_chat_template:
default_system_str = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."
system_str = item.get(system_key, default_system_str)
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'system', 'content': system_str},
{'role': 'user', 'content': prompt_str}
],
add_generation_prompt=True

View File

@@ -0,0 +1,82 @@
from typing import Callable, List
import re

RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]


def r1_extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    except Exception:
        print("r1_extract_xml_answer returned empty string")
        return ""


def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:
        return [0.0] * len(prompts)
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]


def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions or not answer:
        return [0.0] * len(prompts)
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:
        return [0.0] * len(prompts)
    scores = []
    for completion in completions:
        if not completion:
            scores.append(0.0)
            continue
        reason_start = completion.find("<think>")
        reason_end = completion.find("</think>")
        answer_start = completion.find("<answer>")
        answer_end = completion.find("</answer>")
        if (reason_start != -1 and reason_end != -1 and
                answer_start != -1 and answer_end != -1 and
                reason_start < reason_end < answer_start < answer_end):
            # Skip past the opening tags (len("<think>") == 7,
            # len("<answer>") == 8) to slice out the enclosed content.
            reason_content = completion[reason_start + len("<think>"):reason_end].strip()
            answer_content = completion[answer_start + len("<answer>"):answer_end].strip()
            if reason_content and answer_content:
                scores.append(0.5)
                continue
        scores.append(0.0)
    return scores


def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:
        return [0.0] * len(prompts)
    pattern = r"<think>\n.*?\n</think>\n<answer>.*?</answer>"
    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
    return [0.5 if match else 0.0 for match in matches]


def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:
        return [0.0] * len(prompts)
    scores = []
    for text in completions:
        if not text:
            scores.append(0.0)
            continue
        count = 0.0
        if text.count("<think>\n") == 1:
            count += 0.125
        if text.count("</think>") == 1:
            count += 0.125
        if text.count("<answer>") == 1:
            count += 0.125
        if text.count("</answer>") == 1:
            count += 0.125
        # Penalize any trailing text after the closing </answer> tag.
        end_text = text.split("</answer>")[-1]
        count -= len(end_text) * 0.001
        scores.append(max(0.0, count))
    return scores
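All of the extracted functions share the RewardFunctions signature, so a trainer can loop over them uniformly. A minimal smoke test with toy data (not trainer code; the expected scores follow from the definitions above):

prompts = ["What is 6 * 7?"]
completions = ["<think>\nWe compute 6 * 7 = 42.\n</think>\n<answer>42</answer>"]
answer = ["42"]

# Expected for this well-formed completion: accuracy 2.0, int 0.5,
# strict format 0.5, soft format 0.5, count_xml 0.5.
for func in [r1_accuracy_reward_func, r1_int_reward_func,
             r1_strict_format_reward_func, r1_soft_format_reward_func,
             r1_count_xml]:
    print(func.__name__, func(prompts=prompts, completions=completions, answer=answer))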

View File

@ -1,16 +1,16 @@
# Copyright © 2024 Apple Inc.
from typing import List, Optional, Tuple, Generator, Callable, Any
from dataclasses import dataclass, field
from pathlib import Path
import time
import re
from mlx.utils import tree_flatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
from ..models import cache
@@ -50,88 +50,16 @@ class GRPOTrainingArgs(TrainingArgs):
)
def generate_grpo(
model: nn.Module,
prompts,
max_tokens,
tokenizer,
group_size,
is_training=False,
end_token: str = "</answer>",
temperature: float = 0.8
):
if len(prompts.shape) == 1:
prompts = prompts[None, :]
if prompts.shape[1] == 0:
@@ -144,19 +72,38 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
try:
for idx in range(batch_size):
current_tokens = []
if is_training:
current_input = expanded_prompts[idx]
prompt_cache = cache.make_prompt_cache(model)
logits = model(current_input[None], cache=prompt_cache)[:, -1]
mx.eval(logits, prompt_cache)
while len(current_tokens) < max_tokens:
logits_temp = logits / temperature
probs = nn.softmax(logits_temp, axis=-1)
next_token = mx.random.categorical(logits_temp)
token = next_token.item()
test_sequence = current_tokens + [token]
if (len(test_sequence) >= len(end_sequence) and
mx.array_equal(
mx.array(test_sequence[-len(end_sequence):]),
end_sequence
)):
current_tokens.append(token)
break
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
current_input = mx.array([token])
logits = model(current_input[None], cache=prompt_cache)[:, -1]
mx.eval(current_input, logits, probs, next_token, token)
else:
generator = generate_step(
expanded_prompts[idx],
model,
max_tokens=max_tokens,
sampler=lambda x: mx.random.categorical(x / temperature)
)
for token, _ in generator:
test_sequence = current_tokens + [token]
if (len(test_sequence) >= len(end_sequence) and
mx.array_equal(
@@ -168,27 +115,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
if current_tokens:
results.append(mx.array(current_tokens))
mx.metal.clear_cache()
mx.eval(results)
return results
@@ -207,12 +137,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
seq_logits = logits[i, :seq_len]
seq_targets = targets[i, :seq_len]
log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
log_probs,
seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)
per_token_logps.append(token_log_probs)
mx.eval(logits)
return per_token_logps
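For intuition, the take_along_axis gather above selects, for each position, the log-probability of the token that was actually generated. A standalone toy sketch with made-up shapes:

import mlx.core as mx
import mlx.nn as nn

# Toy logits for a 4-token sequence over a 6-token vocabulary.
seq_logits = mx.random.normal((4, 6))
seq_targets = mx.array([1, 3, 0, 5])  # tokens actually taken at each step

log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(4, 1), axis=-1
).squeeze(-1)
print(token_log_probs.shape)  # (4,): one log-prob per generated token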
@@ -399,7 +327,7 @@ def grpo_loss(
}
if is_validation:
print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
print(f"\nValidation sample generation:\n{all_completion_texts}\n")
print(f"Validation sample answer:\n{answer_text[-1]}\n")
mx.metal.clear_cache()
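On the identical-generations fix: the training path previously picked the next token with mx.argmax, which is deterministic, so all group_size completions for a prompt collapsed into the same sequence and left no reward variance within the group for GRPO to exploit. Sampling with mx.random.categorical over temperature-scaled logits gives each completion an independent draw. A standalone illustration with toy logits:

import mlx.core as mx

logits = mx.array([2.0, 1.5, 0.5, 0.1])  # toy next-token logits
temperature = 0.8

# Old behaviour: argmax is deterministic, every draw is the same token.
greedy = [mx.argmax(logits).item() for _ in range(4)]

# New behaviour: categorical sampling over temperature-scaled logits,
# so completions within a group can differ.
sampled = [mx.random.categorical(logits / temperature).item() for _ in range(4)]

print(greedy)   # always [0, 0, 0, 0]
print(sampled)  # varies, e.g. [0, 1, 0, 2]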