first working prototype; will try training it out at home

Goekdeniz-Guelmez 2025-02-03 12:05:29 +01:00
parent 23d75cd7ad
commit 1d9e4802f0
2 changed files with 254 additions and 122 deletions

View File

@@ -9,7 +9,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt, answer) tuple format required by GRPO trainer.
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
"""
def __init__(
self,
@@ -20,15 +20,14 @@
):
self._data = []
for item in data:
# Get prompt and answer text
prompt = str(item[prompt_key])
answer = str(item[answer_key])
# Store as (prompt, answer) tuple
self._data.append((prompt, answer))
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[str, str]:
"""Returns a (prompt, answer) tuple for the given index."""
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]
def __len__(self) -> int:
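
For reference, a minimal sketch of what a single entry looks like under the new format; the toy tokenizer below is a made-up stand-in, the real dataset uses the model's tokenizer:

# Stand-in tokenizer with the only method the dataset needs: encode().
class ToyTokenizer:
    def encode(self, text):
        return [ord(c) for c in text]  # placeholder token ids

tokenizer = ToyTokenizer()
item = {"prompt": "What is 2 + 2?", "answer": "4"}

prompt_str = str(item["prompt"])
answer_str = str(item["answer"])
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)

# dataset[idx] now returns this 4-tuple instead of the old (prompt, answer) pair.
entry = (prompt_tokens, answer_tokens, prompt_str, answer_str)
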

View File

@@ -12,7 +12,7 @@ from mlx.utils import tree_flatten
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
from mlx_lm import generate
from mlx_lm.utils import generate_step
@dataclass
@@ -35,6 +35,70 @@ class GRPOTrainingArgs(TrainingArgs):
)
def generate_for_grpo(
model,
prompt,
max_tokens,
tokenizer,
temperature=1.0
):
try:
# Ensure prompt is the right shape
if len(prompt.shape) == 1:
prompt = prompt[None, :]
# Initialize generation
generated = []
current_prompt = prompt[0]
for step in range(max_tokens):
try:
# Get model output with explicit shape checking
current_batch = current_prompt[None, :]
logits = model(current_batch)
# Ensure we have the last token logits
token_logits = logits[0, -1]
# Apply temperature scaling
if temperature > 0:
token_logits = token_logits / temperature
# Sample the next token; mx.random.categorical expects unnormalized
# logits, so no explicit softmax is needed here
next_token = mx.random.categorical(token_logits[None, :])
next_token = next_token[0]
# Force evaluation to catch any issues
mx.eval(next_token)
token_value = next_token.item()
# Add to generated sequence
generated.append(next_token)
current_prompt = mx.concatenate([current_prompt, next_token[None]])
if token_value == tokenizer.eos_token_id:
break
except Exception as e:
raise
if not generated:
return prompt[0]
try:
result = mx.concatenate([prompt[0], mx.stack(generated)])
mx.eval(result)
return result
except Exception as e:
raise
except Exception as e:
raise
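
A rough usage sketch of the helper above; RandomModel and StubTokenizer are made-up stand-ins that only mimic the shapes and attributes generate_for_grpo touches, while the real call passes the MLX model and tokenizer being trained:

import mlx.core as mx

class RandomModel:
    """Stand-in language model: maps (batch, seq) token ids to random logits."""
    vocab_size = 32000
    def __call__(self, tokens):
        batch, seq = tokens.shape
        return mx.random.normal((batch, seq, self.vocab_size))

class StubTokenizer:
    eos_token_id = 2
    def decode(self, ids):
        return " ".join(str(i) for i in ids)

prompt = mx.array([1, 15, 27, 9])  # hypothetical prompt token ids
completion_ids = generate_for_grpo(
    RandomModel(),
    prompt,
    max_tokens=8,
    tokenizer=StubTokenizer(),
    temperature=1.0,
)
print(StubTokenizer().decode(completion_ids.tolist()))
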
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
@@ -45,42 +109,45 @@ def r1_extract_xml_answer(text: str) -> str:
print("[extract_xml_answer] Failed to extract answer from: ", text)
return ""
def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Calculates reward based on accuracy of extracted answers.
Args:
prompts: List of input prompts
completions: List of completion strings
answer: Expected answer or list of answers
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def r1_int_reward_func(completions, **kwargs) -> list[float]:
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Rewards numerical responses.
Args:
prompts: List of input prompts
completions: List of completion strings
answer: Expected answer or list of answers
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Rewards completions with flexible XML format."""
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Rewards completions with strict XML format.
Args:
prompts: List of input prompts
completions: List of completion strings
answer: Expected answer or list of answers
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
@@ -88,98 +155,128 @@ def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
"""Rewards completions with flexible XML format.
Args:
completions: List of completion strings
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
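
To make the reward values concrete, here is what a single well-formed completion would score under the keyword-argument interface above; the accuracy and integer values assume r1_extract_xml_answer returns the text between the <answer> tags, whose body is elided from this hunk:

prompts = ["What is 2 + 2?"]
completions = ["<think>2 + 2 = 4</think> <answer>4</answer>"]
answer = ["4"]

soft = r1_soft_format_reward_func(prompts=prompts, completions=completions, answer=answer)
# -> [0.5]: the completion matches <think>.*?</think>\s*<answer>.*?</answer>

acc = r1_accuracy_reward_func(prompts=prompts, completions=completions, answer=answer)
# -> [2.0] if the extracted answer ("4") equals the reference, else [0.0]

ints = r1_int_reward_func(prompts=prompts, completions=completions, answer=answer)
# -> [0.5], assuming the same extraction yields "4", which is purely numeric
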
def r1_count_xml(text: str) -> float:
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Calculates score based on XML formatting.
Args:
text: Input text string
prompts: List of input prompts (unused)
completions: List of completion strings to evaluate
answer: Expected answer or list of answers (unused)
**kwargs: Additional arguments
Returns:
float: Score based on XML tag presence and formatting
list[float]: List of scores based on XML tag presence and formatting
"""
count = 0.0
if text.count("<think>\n") == 1:
count += 0.125
if text.count("\n</think>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
scores = []
for text in completions:
count = 0.0
if text.count("<think>\n") == 1:
count += 0.125
if text.count("\n</think>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1])*0.001
if text.count("\n</answer>") == 1:
count += 0.125
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
return count
scores.append(count)
return scores
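
A worked example of the per-completion XML score; prompts and answer are accepted for interface consistency but unused, and a cleanly formatted completion with nothing after the closing tag lands at 0.5:

text = "<think>\nreasoning\n</think>\n<answer>\n4\n</answer>\n"

# "<think>\n", "\n</think>\n", "\n<answer>\n" and "\n</answer>" each occur
# exactly once: 4 * 0.125 = 0.5. Nothing trails the closing tag, so the
# per-character trailing-text penalties are zero.
scores = r1_count_xml(prompts=["unused"], completions=[text], answer=["unused"])
# -> [0.5]
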
def grpo_loss(
model,
tokenizer,
prompts,
reward_funcs=None,
beta=0.1,
group_size=4,
epsilon=1e-4,
ref_model=None
):
"""
Calculates the GRPO loss with support for multiple reward functions.
model,
tokenizer,
batch,
reward_funcs=None,
beta=0.1,
group_size=4,
epsilon=1e-4,
ref_model=None,
max_tokens=128,
temperature=1.0
):
"""Modified GRPO loss function with better error handling"""
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
Args:
model: The model to optimize
tokenizer: Tokenizer for processing inputs
prompts: List of input prompts
reward_funcs: List of reward functions to use
beta: KL penalty coefficient
group_size: Number of completions per prompt
epsilon: Small constant for numerical stability
ref_model: Optional reference model for KL divergence
Returns:
tuple: (loss, total_sequence_length, metrics_dict)
"""
batch_size = len(prompts)
# Generate multiple completions for each prompt
# Generate completions for each prompt
all_completions = []
all_completion_texts = []
for prompt in prompts:
for prompt in prompt_tokens:
prompt_tensor = mx.array(prompt)
prompt_completions = []
prompt_completion_texts = []
# Generate group_size completions for each prompt
for _ in range(group_size):
completion = generate(model, tokenizer, prompt)
prompt_completions.append(completion)
try:
completion_ids = generate_for_grpo(
model,
prompt_tensor,
max_tokens,
tokenizer=tokenizer,
temperature=temperature
)
# Verify completion_ids is not None
if completion_ids is None:
print("Warning: generate_for_grpo returned None")
break
completion_text = tokenizer.decode(completion_ids.tolist())
prompt_completions.append(completion_ids)
prompt_completion_texts.append(completion_text)
except Exception as e:
print(f"Error in completion generation: {str(e)}")
# Fallback to using original prompt
prompt_completions.append(prompt_tensor)
prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
all_completions.extend(prompt_completions)
all_completion_texts.extend(prompt_completion_texts)
# Verify we have the expected number of completions
assert len(all_completions) == batch_size * group_size
assert len(all_completion_texts) == batch_size * group_size
# Tokenize all prompts + completions
tokenized_inputs = tokenizer(
[p + c for p, c in zip(prompts * group_size, all_completions)],
return_tensors="np",
padding=True
)
# Expand answer_text and prompt_text to match completion groups
expanded_answers = []
expanded_prompts = []
for i in range(batch_size):
expanded_answers.extend([answer_text[i]] * group_size)
expanded_prompts.extend([prompt_text[i]] * group_size)
# Verify we have the expected number of completions
assert len(all_completions) == batch_size * group_size
assert len(all_completion_texts) == batch_size * group_size
inputs = mx.array(tokenized_inputs["input_ids"])
attention_mask = mx.array(tokenized_inputs["attention_mask"])
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
attention_masks = []
# Get lengths for proper masking
for completion_ids in all_completions:
padding_length = max_length - completion_ids.shape[0]
if padding_length > 0:
padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
padded_ids = mx.concatenate([completion_ids, padding])
mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
else:
padded_ids = completion_ids
mask = mx.ones_like(completion_ids)
padded_completions.append(padded_ids)
attention_masks.append(mask)
inputs = mx.stack(padded_completions)
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
# Get logits from current model
logits = model(inputs).astype(mx.float32)
# Calculate log probabilities
log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
# Prepare targets
targets = inputs[:, 1:]
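
For the shape bookkeeping here, a self-contained sketch of how the per-token log-probabilities are gathered at the target ids and masked; dimensions are illustrative, not taken from a real model:

import mlx.core as mx
import mlx.nn as nn

batch, seq, vocab = 2, 5, 11
logits = mx.random.normal((batch, seq, vocab))        # stands in for model(inputs)
inputs = mx.random.randint(0, vocab, (batch, seq))    # padded token ids
attention_mask = mx.ones((batch, seq))

log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
targets = inputs[:, 1:]

# log p(target_t | tokens_<t) for every position, shape (batch, seq - 1)
token_log_probs = mx.take_along_axis(
    log_probs,
    targets.reshape(*targets.shape, 1),
    axis=-1,
).squeeze(-1)

# Zero out padding positions before any per-sequence reduction.
length_mask = attention_mask[:, 1:]
masked_log_probs = token_log_probs * length_mask
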
@@ -197,7 +294,7 @@ def grpo_loss(
else:
ref_logits = model(inputs).astype(mx.float32)
ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
ref_token_log_probs = mx.take_along_axis(
ref_log_probs,
targets.reshape(*targets.shape, 1),
@@ -210,7 +307,11 @@ def grpo_loss(
# Calculate combined rewards from all reward functions
rewards = mx.zeros((len(all_completions),))
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(all_completions))
func_rewards = mx.array(reward_func(
prompts=prompt_text,
completions=all_completion_texts,
answer=answer_text
))
rewards += func_rewards
# Normalize rewards if using multiple reward functions
@@ -245,8 +346,12 @@ def grpo_loss(
# Collect metrics for each reward function separately
reward_metrics = {}
for i, reward_func in enumerate(reward_funcs):
func_rewards = mx.array(reward_func(all_completions))
func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
func_rewards = mx.array(reward_func(
prompts=prompt_text,
completions=all_completion_texts,
answer=answer_text
))
# func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
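
The advantage computation itself sits in the elided lines between these hunks; for orientation, a generic GRPO-style group normalization looks roughly like the following sketch, which is not a verbatim copy of the commit:

import mlx.core as mx

batch_size, group_size, epsilon, beta = 2, 4, 1e-4, 0.1
rewards = mx.random.uniform(shape=(batch_size * group_size,))

# Normalize each prompt's group of rewards to get group-relative advantages.
grouped = rewards.reshape(batch_size, group_size)
mean = mx.mean(grouped, axis=1, keepdims=True)
std = mx.std(grouped, axis=1, keepdims=True)
advantages = ((grouped - mean) / (std + epsilon)).reshape(-1)

# The policy term then weights per-sequence log-probabilities by these
# advantages, and beta scales a KL penalty against the reference model:
#   loss ≈ -mean(advantages * seq_log_probs - beta * kl_per_sequence)
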
@@ -264,26 +369,30 @@ def grpo_loss(
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Creates batches from prompt-answer pairs for GRPO training.
Creates batches from dataset entries for GRPO training.
Args:
dataset: List of (prompt, answer) pairs
dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
tokenizer: Tokenizer for processing inputs
batch_size: Size of each batch
max_seq_length: Maximum sequence length
train: Whether this is for training
Yields:
List of prompts for the current batch
Tuple containing:
- prompts_tokens: List of token sequences for current batch
- answers_tokens: List of token sequences
- prompts_text: List of prompt strings
- answers_text: List of answer strings
"""
# Verify dataset is not empty and has correct format
if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
raise ValueError("Dataset must be a list of (prompt, answer) pairs")
# Sort by combined length of prompt + answer
# Verify dataset format
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
# Sort by combined length of prompt + answer tokens
idx = sorted(range(len(dataset)),
key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
@@ -306,22 +415,22 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
# Get current batch of prompt-answer pairs
# Get current batch
current_batch = [dataset[j] for j in batch_idx[i]]
# Extract prompts and answers
prompts = [pair[0] for pair in current_batch]
answers = [pair[1] for pair in current_batch]
if any(len(p) > max_seq_length for p in prompts):
# Extract all components
prompts_tokens = [item[0] for item in current_batch]
answers_tokens = [item[1] for item in current_batch]
prompts_text = [item[2] for item in current_batch]
answers_text = [item[3] for item in current_batch]
if any(len(p) > max_seq_length for p in prompts_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
# For GRPO, we only need to yield the prompts
# The answers will be used by the reward functions
yield prompts
yield prompts_tokens, answers_tokens, prompts_text, answers_text
if not train:
break
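
A usage sketch of the new generator interface; train_set, tokenizer, and model are assumed to already exist, and the yielded batch tuple is passed straight through to grpo_loss:

for batch in iterate_grpo_batches(
    dataset=train_set,     # list of (prompt_tokens, answer_tokens, prompt_str, answer_str)
    tokenizer=tokenizer,
    batch_size=4,
    max_seq_length=512,
    train=True,
):
    prompts_tokens, answers_tokens, prompts_text, answers_text = batch
    loss, ntoks, metrics = grpo_loss(
        model=model,
        tokenizer=tokenizer,
        batch=batch,
        reward_funcs=[r1_accuracy_reward_func, r1_int_reward_func, r1_count_xml],
        group_size=4,
    )
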
@@ -342,11 +451,19 @@ def evaluate_grpo(
loss: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
"""
Evaluate model using GRPO loss.
Returns:
tuple: (average loss, number of tokens, average metrics)
"""
all_losses = 0
ntokens = 0
all_metrics = None # Initialize metrics dictionary
# Create iterator for batches
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
# Iterate through batches
for _, batch in zip(
index_iterator,
iterate_batches(
@@ -356,35 +473,41 @@ def evaluate_grpo(
max_seq_length=max_seq_length,
),
):
prompts = batch
# Calculate loss for current batch
losses, toks, metrics = loss(
model=model,
tokenizer=tokenizer,
prompts=prompts,
batch=batch,
reward_funcs=reward_funcs,
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model
)
# Accumulate losses and tokens
all_losses += losses * toks
ntokens += toks
# Accumulate metrics
if all_metrics is None:
all_metrics = {k: v * toks for k, v in metrics.items()}
else:
for k, v in metrics.items():
all_metrics[k] += v * toks
# Evaluate accumulated values
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
return avg_loss, ntokens, avg_metrics
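
The accumulation above is a token-weighted average, so batches with more tokens count for more; in plain numbers (illustrative values only):

# (per-batch mean loss, tokens in batch) for two hypothetical batches
batches = [(2.0, 100), (3.0, 50)]

total_loss = sum(loss * toks for loss, toks in batches)   # 200.0 + 150.0 = 350.0
total_toks = sum(toks for _, toks in batches)             # 150
avg_loss = total_loss / total_toks                        # ~2.33, weighted toward the larger batch
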
@@ -420,8 +543,18 @@ def train_grpo(
state = [model.state, optimizer.state]
def step(batch):
# Forward and backward pass
(loss, toks, metrics), grad = loss_value_and_grad(model, *batch)
(loss, toks, metrics), grad = loss_value_and_grad(
model,
tokenizer=tokenizer,
batch=batch,
reward_funcs=reward_funcs,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
ref_model=ref_model
)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
@@ -430,7 +563,7 @@ def train_grpo(
optimizer.update(model, grad)
return loss, toks, metrics
loss_value_and_grad = nn.value_and_grad(model, loss)
losses = 0