Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-09 18:36:38 +08:00)

Commit 1d9e4802f0 (parent 23d75cd7ad)
Commit message: first working prototype, will try training out at home
@@ -9,7 +9,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt, answer) tuple format required by GRPO trainer.
+    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
     """
     def __init__(
         self,
@@ -20,15 +20,14 @@ class GRPODataset:
     ):
         self._data = []
         for item in data:
-            # Get prompt and answer text
-            prompt = str(item[prompt_key])
-            answer = str(item[answer_key])
-
-            # Store as (prompt, answer) tuple
-            self._data.append((prompt, answer))
+            prompt_str = str(item[prompt_key])
+            answer_str = str(item[answer_key])
+            prompt_tokens = tokenizer.encode(prompt_str)
+            answer_tokens = tokenizer.encode(answer_str)
+            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

-    def __getitem__(self, idx: int) -> Tuple[str, str]:
-        """Returns a (prompt, answer) tuple for the given index."""
+    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]

     def __len__(self) -> int:
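For context, a minimal, self-contained sketch of the four-field tuple format the dataset now stores. The dummy tokenizer and example record below are made up for illustration; only the tuple layout mirrors the diff above.

from typing import List, Tuple

class DummyTokenizer:
    def encode(self, text: str) -> List[int]:
        # Stand-in for a real tokenizer; maps characters to code points.
        return [ord(c) for c in text]

tokenizer = DummyTokenizer()
data = [{"prompt": "What is 2 + 2?", "answer": "4"}]

dataset: List[Tuple[List[int], List[int], str, str]] = []
for item in data:
    prompt_str = str(item["prompt"])
    answer_str = str(item["answer"])
    dataset.append(
        (tokenizer.encode(prompt_str), tokenizer.encode(answer_str), prompt_str, answer_str)
    )

prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]
print(len(prompt_tokens), answer_str)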
@@ -12,7 +12,7 @@ from mlx.utils import tree_flatten

 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

-from mlx_lm import generate
+from mlx_lm.utils import generate_step


 @dataclass
@@ -35,6 +35,70 @@ class GRPOTrainingArgs(TrainingArgs):
     )


+def generate_for_grpo(
+    model,
+    prompt,
+    max_tokens,
+    tokenizer,
+    temperature=1.0
+):
+    try:
+
+        # Ensure prompt is the right shape
+        if len(prompt.shape) == 1:
+            prompt = prompt[None, :]
+
+        # Initialize generation
+        generated = []
+        current_prompt = prompt[0]
+
+        for step in range(max_tokens):
+            try:
+                # Get model output with explicit shape checking
+                current_batch = current_prompt[None, :]
+
+                logits = model(current_batch)
+
+                # Ensure we have the last token logits
+                token_logits = logits[0, -1]
+
+                # Apply temperature and get probabilities
+                if temperature > 0:
+                    token_logits = token_logits / temperature
+                probs = mx.softmax(token_logits)
+
+                # Sample the next token
+                next_token = mx.random.categorical(probs[None, :])
+                next_token = next_token[0]
+
+                # Force evaluation to catch any issues
+                mx.eval(next_token)
+                token_value = next_token.item()
+
+                # Add to generated sequence
+                generated.append(next_token)
+                current_prompt = mx.concatenate([current_prompt, next_token[None]])
+
+                if token_value == tokenizer.eos_token_id:
+                    break
+
+            except Exception as e:
+                raise
+
+        if not generated:
+            return prompt[0]
+
+        try:
+            result = mx.concatenate([prompt[0], mx.stack(generated)])
+            mx.eval(result)
+            return result
+        except Exception as e:
+            raise
+
+    except Exception as e:
+        raise
+
+
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
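As a side note, a standalone sketch of temperature sampling with MLX primitives, independent of the function above (the vocabulary size is made up). mx.random.categorical samples from unnormalized log-probabilities, so temperature-scaled logits can be passed to it directly.

import mlx.core as mx

def sample_next_token(token_logits: mx.array, temperature: float = 1.0) -> mx.array:
    # mx.random.categorical interprets its input as unnormalized log-probabilities.
    if temperature > 0:
        return mx.random.categorical(token_logits / temperature)
    return mx.argmax(token_logits)  # greedy decoding when temperature is 0

logits = mx.random.normal((32000,))  # fake vocabulary-sized logits
token = sample_next_token(logits, temperature=0.8)
mx.eval(token)
print(token.item())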
@@ -45,42 +109,45 @@ def r1_extract_xml_answer(text: str) -> str:
         print("[extract_xml_answer] Failed to extract answer from: ", text)
         return ""

-def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Calculates reward based on accuracy of extracted answers.

     Args:
         prompts: List of input prompts
         completions: List of completion strings
         answer: Expected answer or list of answers
         **kwargs: Additional arguments

     Returns:
         list[float]: Reward values for each completion
     """
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
     return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

-def r1_int_reward_func(completions, **kwargs) -> list[float]:
+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Rewards numerical responses.

     Args:
         prompts: List of input prompts
         completions: List of completion strings
         answer: Expected answer or list of answers
         **kwargs: Additional arguments

     Returns:
         list[float]: Reward values for each completion
     """
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

-def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Rewards completions with flexible XML format."""
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Rewards completions with strict XML format.

     Args:
         prompts: List of input prompts
         completions: List of completion strings
         answer: Expected answer or list of answers
         **kwargs: Additional arguments

     Returns:
         list[float]: Reward values for each completion
     """
@@ -88,98 +155,128 @@ def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
     matches = [re.match(pattern, r) for r in completions]
     return [0.5 if match else 0.0 for match in matches]

-def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
-    """Rewards completions with flexible XML format.
-
-    Args:
-        completions: List of completion strings
-        **kwargs: Additional arguments
-
-    Returns:
-        list[float]: Reward values for each completion
-    """
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [re.match(pattern, r) for r in completions]
-    return [0.5 if match else 0.0 for match in matches]
-
-def r1_count_xml(text: str) -> float:
+def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Calculates score based on XML formatting.

     Args:
-        text: Input text string
-
+        prompts: List of input prompts (unused)
+        completions: List of completion strings to evaluate
+        answer: Expected answer or list of answers (unused)
+        **kwargs: Additional arguments
     Returns:
-        float: Score based on XML tag presence and formatting
+        list[float]: List of scores based on XML tag presence and formatting
     """
-    count = 0.0
-    if text.count("<think>\n") == 1:
-        count += 0.125
-    if text.count("\n</think>\n") == 1:
-        count += 0.125
-    if text.count("\n<answer>\n") == 1:
-        count += 0.125
-    if text.count("\n</answer>") == 1:
-        count += 0.125
-    return count
+    scores = []
+    for text in completions:
+        count = 0.0
+        if text.count("<think>\n") == 1:
+            count += 0.125
+        if text.count("\n</think>\n") == 1:
+            count += 0.125
+        if text.count("\n<answer>\n") == 1:
+            count += 0.125
+            count -= len(text.split("\n</answer>\n")[-1])*0.001
+        if text.count("\n</answer>") == 1:
+            count += 0.125
+            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+        scores.append(count)
+    return scores


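A small sketch of why the uniform (prompts, completions, answer, **kwargs) signature matters: every reward function can now be called with the same keyword arguments and the per-completion rewards summed. The toy reward functions and inputs below are made up and are not part of the commit.

def fake_length_reward(prompts, completions, answer, **kwargs):
    return [min(len(c) / 100.0, 1.0) for c in completions]

def fake_exact_match_reward(prompts, completions, answer, **kwargs):
    return [2.0 if a in c else 0.0 for c, a in zip(completions, answer)]

reward_funcs = [fake_length_reward, fake_exact_match_reward]
prompts = ["What is 2 + 2?"]
completions = ["<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>"]
answer = ["4"]

totals = [0.0] * len(completions)
for func in reward_funcs:
    rewards = func(prompts=prompts, completions=completions, answer=answer)
    totals = [t + r for t, r in zip(totals, rewards)]
print(totals)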
 def grpo_loss(
-    model,
-    tokenizer,
-    prompts,
-    reward_funcs=None,
-    beta=0.1,
-    group_size=4,
-    epsilon=1e-4,
-    ref_model=None
-):
-    """
-    Calculates the GRPO loss with support for multiple reward functions.
+    model,
+    tokenizer,
+    batch,
+    reward_funcs=None,
+    beta=0.1,
+    group_size=4,
+    epsilon=1e-4,
+    ref_model=None,
+    max_tokens=128,
+    temperature=1.0
+):
+    """Modified GRPO loss function with better error handling"""
+    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    batch_size = len(prompt_tokens)

-    Args:
-        model: The model to optimize
-        tokenizer: Tokenizer for processing inputs
-        prompts: List of input prompts
-        reward_funcs: List of reward functions to use
-        beta: KL penalty coefficient
-        group_size: Number of completions per prompt
-        epsilon: Small constant for numerical stability
-        ref_model: Optional reference model for KL divergence
-
-    Returns:
-        tuple: (loss, total_sequence_length, metrics_dict)
-    """
-    batch_size = len(prompts)
-    # Generate multiple completions for each prompt
+    # Generate completions for each prompt
     all_completions = []
     all_completion_texts = []

-    for prompt in prompts:
+    for prompt in prompt_tokens:
+        prompt_tensor = mx.array(prompt)
         prompt_completions = []
         prompt_completion_texts = []

         # Generate group_size completions for each prompt
         for _ in range(group_size):
-            completion = generate(model, tokenizer, prompt)
-            prompt_completions.append(completion)
+            try:
+                completion_ids = generate_for_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer=tokenizer,
+                    temperature=temperature
+                )
+
+                # Verify completion_ids is not None
+                if completion_ids is None:
+                    print("Warning: generate_for_grpo returned None")
+                    break
+
+                completion_text = tokenizer.decode(completion_ids.tolist())
+
+                prompt_completions.append(completion_ids)
+                prompt_completion_texts.append(completion_text)
+
+            except Exception as e:
+                print(f"Error in completion generation: {str(e)}")
+                # Fallback to using original prompt
+                prompt_completions.append(prompt_tensor)
+                prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))

         all_completions.extend(prompt_completions)
         all_completion_texts.extend(prompt_completion_texts)

-    # Verify we have the expected number of completions
-    assert len(all_completions) == batch_size * group_size
-    assert len(all_completion_texts) == batch_size * group_size

-    # Tokenize all prompts + completions
-    tokenized_inputs = tokenizer(
-        [p + c for p, c in zip(prompts * group_size, all_completions)],
-        return_tensors="np",
-        padding=True
-    )
+    # Expand answer_text and prompt_text to match completion groups
+    expanded_answers = []
+    expanded_prompts = []
+    for i in range(batch_size):
+        expanded_answers.extend([answer_text[i]] * group_size)
+        expanded_prompts.extend([prompt_text[i]] * group_size)
+
+    # Verify we have the expected number of completions
+    assert len(all_completions) == batch_size * group_size
+    assert len(all_completion_texts) == batch_size * group_size

-    inputs = mx.array(tokenized_inputs["input_ids"])
-    attention_mask = mx.array(tokenized_inputs["attention_mask"])
+    max_length = max(ids.shape[0] for ids in all_completions)
+    padded_completions = []
+    attention_masks = []
+
+    # Get lengths for proper masking
+    for completion_ids in all_completions:
+        padding_length = max_length - completion_ids.shape[0]
+        if padding_length > 0:
+            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
+            padded_ids = mx.concatenate([completion_ids, padding])
+            mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
+        else:
+            padded_ids = completion_ids
+            mask = mx.ones_like(completion_ids)
+        padded_completions.append(padded_ids)
+        attention_masks.append(mask)
+
+    inputs = mx.stack(padded_completions)
+    attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)

     # Get logits from current model
     logits = model(inputs).astype(mx.float32)

     # Calculate log probabilities
-    log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
+    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)

     # Prepare targets
     targets = inputs[:, 1:]
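For readers unfamiliar with the log-probability gathering used here (and for the reference model just below), a shape-level sketch using the same nn.log_softmax and mx.take_along_axis calls the diff switches to. The dimensions are toy values, not the repo's code.

import mlx.core as mx
import mlx.nn as nn

B, T, V = 2, 6, 11                    # batch, sequence length, vocab (toy sizes)
logits = mx.random.normal((B, T, V))
inputs = mx.random.randint(0, V, (B, T))

log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)   # predictions for positions 1..T-1
targets = inputs[:, 1:]
token_log_probs = mx.squeeze(
    mx.take_along_axis(log_probs, targets.reshape(*targets.shape, 1), axis=-1),
    axis=-1,
)                                                         # shape (B, T-1)
mx.eval(token_log_probs)
print(token_log_probs.shape)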
@@ -197,7 +294,7 @@ def grpo_loss(
     else:
         ref_logits = model(inputs).astype(mx.float32)

-    ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
+    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
     ref_token_log_probs = mx.take_along_axis(
         ref_log_probs,
         targets.reshape(*targets.shape, 1),
@@ -210,7 +307,11 @@ def grpo_loss(
     # Calculate combined rewards from all reward functions
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
-        func_rewards = mx.array(reward_func(all_completions))
+        func_rewards = mx.array(reward_func(
+            prompts=prompt_text,
+            completions=all_completion_texts,
+            answer=answer_text
+        ))
         rewards += func_rewards

     # Normalize rewards if using multiple reward functions
@@ -245,8 +346,12 @@ def grpo_loss(
     # Collect metrics for each reward function separately
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
-        func_rewards = mx.array(reward_func(all_completions))
-        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        func_rewards = mx.array(reward_func(
+            prompts=prompt_text,
+            completions=all_completion_texts,
+            answer=answer_text
+        ))
+        # func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
         reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)

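The per-prompt grouping touched on above is GRPO's central step: each completion's reward is normalized within its prompt group to form an advantage. A minimal sketch of that computation follows; it shows the general formulation, not necessarily this commit's exact code.

import mlx.core as mx

def group_normalized_advantages(rewards, batch_size, group_size, epsilon=1e-4):
    grouped = rewards.reshape(batch_size, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.std(grouped, axis=1, keepdims=True)
    # Normalize each reward against the other completions for the same prompt.
    return ((grouped - mean) / (std + epsilon)).reshape(-1)

rewards = mx.array([1.0, 0.0, 2.0, 0.5])   # one prompt, group_size=4
print(group_normalized_advantages(rewards, batch_size=1, group_size=4))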
@@ -264,26 +369,30 @@ def grpo_loss(

 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
-    Creates batches from prompt-answer pairs for GRPO training.
+    Creates batches from dataset entries for GRPO training.

     Args:
-        dataset: List of (prompt, answer) pairs
+        dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
         tokenizer: Tokenizer for processing inputs
         batch_size: Size of each batch
         max_seq_length: Maximum sequence length
         train: Whether this is for training

     Yields:
-        List of prompts for the current batch
+        Tuple containing:
+        - prompts_tokens: List of token sequences for current batch
+        - answers_tokens: List of token sequences
+        - prompts_text: List of prompt strings
+        - answers_text: List of answer strings
     """
-    # Verify dataset is not empty and has correct format
-    if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
-        raise ValueError("Dataset must be a list of (prompt, answer) pairs")
-
-    # Sort by combined length of prompt + answer
+    # Verify dataset format
+    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+    # Sort by combined length of prompt + answer tokens
     idx = sorted(range(len(dataset)),
                 key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))

     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
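A toy illustration of the length-sorted batching used here: indices are ordered by combined prompt+answer token length so each batch holds similarly sized examples. The data below is invented; the real iterator also handles the batch_size check above, shuffling, and distributed sharding.

dataset = [
    ([1, 2, 3], [4], "abc", "d"),
    ([1], [2, 3], "a", "bc"),
    ([1, 2], [3, 4, 5, 6], "ab", "cdef"),
    ([7], [8], "g", "h"),
]
batch_size = 2

idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
batch_idx = [idx[i:i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)]
print(batch_idx)   # [[3, 1], [0, 2]] -- shortest examples batched first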
@@ -306,22 +415,22 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))

     for i in indices:
-        # Get current batch of prompt-answer pairs
+        # Get current batch
         current_batch = [dataset[j] for j in batch_idx[i]]

-        # Extract prompts and answers
-        prompts = [pair[0] for pair in current_batch]
-        answers = [pair[1] for pair in current_batch]
-
-        if any(len(p) > max_seq_length for p in prompts):
+        # Extract all components
+        prompts_tokens = [item[0] for item in current_batch]
+        answers_tokens = [item[1] for item in current_batch]
+        prompts_text = [item[2] for item in current_batch]
+        answers_text = [item[3] for item in current_batch]
+
+        if any(len(p) > max_seq_length for p in prompts_tokens):
             print(
                 f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                 "Long prompts will be truncated."
             )

-        # For GRPO, we only need to yield the prompts
-        # The answers will be used by the reward functions
-        yield prompts
+
+        yield prompts_tokens, answers_tokens, prompts_text, answers_text

         if not train:
             break
@@ -342,11 +451,19 @@ def evaluate_grpo(
     loss: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
+    """
+    Evaluate model using GRPO loss.
+    Returns:
+        tuple: (average loss, number of tokens, average metrics)
+    """
     all_losses = 0
     ntokens = 0
+    all_metrics = None  # Initialize metrics dictionary

+    # Create iterator for batches
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

+    # Iterate through batches
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -356,35 +473,41 @@ def evaluate_grpo(
             max_seq_length=max_seq_length,
         ),
     ):
-        prompts = batch
+        # Calculate loss for current batch
         losses, toks, metrics = loss(
             model=model,
             tokenizer=tokenizer,
-            prompts=prompts,
+            batch=batch,
             reward_funcs=reward_funcs,
             beta=beta,
             group_size=group_size,
             epsilon=epsilon,
             ref_model=ref_model
         )

+        # Accumulate losses and tokens
         all_losses += losses * toks
         ntokens += toks

+        # Accumulate metrics
+        if all_metrics is None:
+            all_metrics = {k: v * toks for k, v in metrics.items()}
+        else:
+            for k, v in metrics.items():
+                all_metrics[k] += v * toks

     # Evaluate accumulated values
     mx.eval(all_losses, ntokens)

     # Aggregate across distributed workers
     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

     # Calculate averages
+    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
     avg_loss = (all_losses / ntokens).item()

     return avg_loss, ntokens, avg_metrics
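The accumulation above is a token-weighted average of the per-batch losses; in plain Python it reduces to the following (numbers are made up).

batches = [(0.9, 120), (1.1, 80)]        # (batch loss, token count) pairs, made up
total_loss, total_tokens = 0.0, 0
for loss, toks in batches:
    total_loss += loss * toks
    total_tokens += toks
print(total_loss / total_tokens)         # 0.98, the token-weighted average loss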
@@ -420,8 +543,18 @@ def train_grpo(
     state = [model.state, optimizer.state]

     def step(batch):

         # Forward and backward pass
-        (loss, toks, metrics), grad = loss_value_and_grad(model, *batch)
+        (loss, toks, metrics), grad = loss_value_and_grad(
+            model,
+            tokenizer=tokenizer,
+            batch=batch,
+            reward_funcs=reward_funcs,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            ref_model=ref_model
+        )

         # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)
@@ -430,7 +563,7 @@ def train_grpo(
         optimizer.update(model, grad)

         return loss, toks, metrics

     loss_value_and_grad = nn.value_and_grad(model, loss)

     losses = 0
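For reference, a minimal self-contained sketch of the nn.value_and_grad pattern the step function relies on: the wrapped loss receives the model first, extra keyword arguments are forwarded, and gradients are taken with respect to the model's trainable parameters. The toy model and loss below are illustrative only, not the trainer's.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, batch):
    x, y = batch
    return mx.mean((model(x) - y) ** 2)

loss_value_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
loss, grads = loss_value_and_grad(model, batch=(x, y))   # keyword args are forwarded
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)
print(loss.item())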