mlx-examples (mirror of https://github.com/ml-explore/mlx-examples.git)

Commit 1d9e4802f0 (parent 23d75cd7ad): first working prototype, will try training out at home

@@ -9,7 +9,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt, answer) tuple format required by GRPO trainer.
+    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
     """
     def __init__(
         self,
@@ -20,15 +20,14 @@ class GRPODataset:
     ):
         self._data = []
         for item in data:
-            # Get prompt and answer text
-            prompt = str(item[prompt_key])
-            answer = str(item[answer_key])
-            # Store as (prompt, answer) tuple
-            self._data.append((prompt, answer))
+            prompt_str = str(item[prompt_key])
+            answer_str = str(item[answer_key])
+            prompt_tokens = tokenizer.encode(prompt_str)
+            answer_tokens = tokenizer.encode(answer_str)
+            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

-    def __getitem__(self, idx: int) -> Tuple[str, str]:
-        """Returns a (prompt, answer) tuple for the given index."""
+    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]

     def __len__(self) -> int:
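
As a quick illustration of the new dataset contract, here is a minimal usage sketch. The toy tokenizer and the positional/keyword argument order are assumptions made for the example; only the 4-tuple returned by __getitem__ comes from the commit itself.

# Sketch only: a toy whitespace "tokenizer" stands in for the real one, and the
# GRPODataset constructor is assumed to accept (data, tokenizer, prompt_key, answer_key).
class ToyTokenizer:
    def encode(self, text):
        # Real tokenizers return subword ids; hashing whitespace tokens is enough here.
        return [hash(word) % 32000 for word in text.split()]

data = [{"prompt": "What is 2 + 2?", "answer": "4"}]
dataset = GRPODataset(data, ToyTokenizer(), prompt_key="prompt", answer_key="answer")

prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]
print(len(prompt_tokens), answer_str)  # token ids for the prompt plus the raw strings
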
@@ -12,7 +12,7 @@ from mlx.utils import tree_flatten

 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

-from mlx_lm import generate
+from mlx_lm.utils import generate_step


 @dataclass
@@ -35,6 +35,70 @@ class GRPOTrainingArgs(TrainingArgs):
     )


+def generate_for_grpo(
+    model,
+    prompt,
+    max_tokens,
+    tokenizer,
+    temperature=1.0
+):
+    try:
+
+        # Ensure prompt is the right shape
+        if len(prompt.shape) == 1:
+            prompt = prompt[None, :]
+
+        # Initialize generation
+        generated = []
+        current_prompt = prompt[0]
+
+        for step in range(max_tokens):
+            try:
+                # Get model output with explicit shape checking
+                current_batch = current_prompt[None, :]
+
+                logits = model(current_batch)
+
+                # Ensure we have the last token logits
+                token_logits = logits[0, -1]
+
+                # Apply temperature and get probabilities
+                if temperature > 0:
+                    token_logits = token_logits / temperature
+                probs = mx.softmax(token_logits)
+
+                # Sample the next token
+                next_token = mx.random.categorical(probs[None, :])
+                next_token = next_token[0]
+
+                # Force evaluation to catch any issues
+                mx.eval(next_token)
+                token_value = next_token.item()
+
+                # Add to generated sequence
+                generated.append(next_token)
+                current_prompt = mx.concatenate([current_prompt, next_token[None]])
+
+                if token_value == tokenizer.eos_token_id:
+                    break
+
+            except Exception as e:
+                raise
+
+        if not generated:
+            return prompt[0]
+
+        try:
+            result = mx.concatenate([prompt[0], mx.stack(generated)])
+            mx.eval(result)
+            return result
+        except Exception as e:
+            raise
+
+    except Exception as e:
+        raise
+
+
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
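
For comparison, a minimal standalone sketch of the sampling step (not the commit's code). As far as the MLX docs describe it, mx.random.categorical samples from unnormalized logits, so scaling by the temperature is enough and an explicit softmax before sampling is not required. Note also that the loop above re-runs the model over the full, growing prompt on every step, whereas the generate_step helper imported earlier reuses a key-value cache between steps.

import mlx.core as mx

def sample_next_token(token_logits: mx.array, temperature: float = 1.0) -> mx.array:
    """Sketch of temperature sampling over the last-position logits."""
    if temperature == 0:
        return mx.argmax(token_logits, axis=-1)  # greedy decoding
    # mx.random.categorical expects (unnormalized) logits, not probabilities.
    return mx.random.categorical(token_logits / temperature)
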
@@ -45,42 +109,45 @@ def r1_extract_xml_answer(text: str) -> str:
         print("[extract_xml_answer] Failed to extract answer from: ", text)
         return ""

-def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Calculates reward based on accuracy of extracted answers.

     Args:
         prompts: List of input prompts
         completions: List of completion strings
         answer: Expected answer or list of answers
         **kwargs: Additional arguments

     Returns:
         list[float]: Reward values for each completion
     """
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
     return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

-def r1_int_reward_func(completions, **kwargs) -> list[float]:
+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Rewards numerical responses.

     Args:
+        prompts: List of input prompts
         completions: List of completion strings
+        answer: Expected answer or list of answers
         **kwargs: Additional arguments

     Returns:
         list[float]: Reward values for each completion
     """
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

-def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Rewards completions with flexible XML format."""
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Rewards completions with strict XML format.

     Args:
+        prompts: List of input prompts
         completions: List of completion strings
+        answer: Expected answer or list of answers
         **kwargs: Additional arguments

     Returns:
         list[float]: Reward values for each completion
     """
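
To make the expected completion format concrete, here is a hypothetical call against the reward functions above; the prompt, completion, and answer strings are made up for illustration.

# Illustrative only; none of these strings come from the repository.
completion = "<think>2 + 2 = 4</think>\n<answer>4</answer>"
prompts = ["What is 2 + 2?"]
answer = ["4"]

# The soft-format pattern matches single-line <think>/<answer> blocks.
print(r1_soft_format_reward_func(prompts, [completion], answer))  # -> [0.5]
# Accuracy and integer rewards depend on r1_extract_xml_answer pulling "4" out
# of the <answer> block; that helper's body falls outside the hunks shown here.
print(r1_accuracy_reward_func(prompts, [completion], answer))
print(r1_int_reward_func(prompts, [completion], answer))
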
@@ -88,98 +155,128 @@ def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
     matches = [re.match(pattern, r) for r in completions]
     return [0.5 if match else 0.0 for match in matches]

-def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
-    """Rewards completions with flexible XML format.
-
-    Args:
-        completions: List of completion strings
-        **kwargs: Additional arguments
-
-    Returns:
-        list[float]: Reward values for each completion
-    """
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [re.match(pattern, r) for r in completions]
-    return [0.5 if match else 0.0 for match in matches]
-
-def r1_count_xml(text: str) -> float:
+def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Calculates score based on XML formatting.

     Args:
-        text: Input text string
+        prompts: List of input prompts (unused)
+        completions: List of completion strings to evaluate
+        answer: Expected answer or list of answers (unused)
+        **kwargs: Additional arguments

     Returns:
-        float: Score based on XML tag presence and formatting
+        list[float]: List of scores based on XML tag presence and formatting
     """
-    count = 0.0
-    if text.count("<think>\n") == 1:
-        count += 0.125
-    if text.count("\n</think>\n") == 1:
-        count += 0.125
-    if text.count("\n<answer>\n") == 1:
-        count += 0.125
-        count -= len(text.split("\n</answer>\n")[-1])*0.001
-    if text.count("\n</answer>") == 1:
-        count += 0.125
-        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
-    return count
+    scores = []
+    for text in completions:
+        count = 0.0
+        if text.count("<think>\n") == 1:
+            count += 0.125
+        if text.count("\n</think>\n") == 1:
+            count += 0.125
+        if text.count("\n<answer>\n") == 1:
+            count += 0.125
+            count -= len(text.split("\n</answer>\n")[-1])*0.001
+        if text.count("\n</answer>") == 1:
+            count += 0.125
+            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+        scores.append(count)
+    return scores


 def grpo_loss(
     model,
     tokenizer,
-    prompts,
+    batch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
     epsilon=1e-4,
-    ref_model=None
-):
-    """
-    Calculates the GRPO loss with support for multiple reward functions.
-
-    Args:
-        model: The model to optimize
-        tokenizer: Tokenizer for processing inputs
-        prompts: List of input prompts
-        reward_funcs: List of reward functions to use
-        beta: KL penalty coefficient
-        group_size: Number of completions per prompt
-        epsilon: Small constant for numerical stability
-        ref_model: Optional reference model for KL divergence
-
-    Returns:
-        tuple: (loss, total_sequence_length, metrics_dict)
-    """
-    batch_size = len(prompts)
-    # Generate multiple completions for each prompt
+    ref_model=None,
+    max_tokens=128,
+    temperature=1.0
+):
+    """Modified GRPO loss function with better error handling"""
+    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    batch_size = len(prompt_tokens)
+
+    # Generate completions for each prompt
     all_completions = []
+    all_completion_texts = []

-    for prompt in prompts:
+    for prompt in prompt_tokens:
+        prompt_tensor = mx.array(prompt)
         prompt_completions = []
+        prompt_completion_texts = []
+
+        # Generate group_size completions for each prompt
         for _ in range(group_size):
-            completion = generate(model, tokenizer, prompt)
-            prompt_completions.append(completion)
+            try:
+                completion_ids = generate_for_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer=tokenizer,
+                    temperature=temperature
+                )
+
+                # Verify completion_ids is not None
+                if completion_ids is None:
+                    print("Warning: generate_for_grpo returned None")
+                    break
+
+                completion_text = tokenizer.decode(completion_ids.tolist())
+
+                prompt_completions.append(completion_ids)
+                prompt_completion_texts.append(completion_text)
+
+            except Exception as e:
+                print(f"Error in completion generation: {str(e)}")
+                # Fallback to using original prompt
+                prompt_completions.append(prompt_tensor)
+                prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
+
         all_completions.extend(prompt_completions)
+        all_completion_texts.extend(prompt_completion_texts)
+
+    # Verify we have the expected number of completions
+    assert len(all_completions) == batch_size * group_size
+    assert len(all_completion_texts) == batch_size * group_size

-    # Tokenize all prompts + completions
-    tokenized_inputs = tokenizer(
-        [p + c for p, c in zip(prompts * group_size, all_completions)],
-        return_tensors="np",
-        padding=True
-    )
+    # Expand answer_text and prompt_text to match completion groups
+    expanded_answers = []
+    expanded_prompts = []
+    for i in range(batch_size):
+        expanded_answers.extend([answer_text[i]] * group_size)
+        expanded_prompts.extend([prompt_text[i]] * group_size)

-    inputs = mx.array(tokenized_inputs["input_ids"])
-    attention_mask = mx.array(tokenized_inputs["attention_mask"])
+    # Verify we have the expected number of completions
+    assert len(all_completions) == batch_size * group_size
+    assert len(all_completion_texts) == batch_size * group_size
+
+    max_length = max(ids.shape[0] for ids in all_completions)
+    padded_completions = []
+    attention_masks = []

-    # Get lengths for proper masking
+    for completion_ids in all_completions:
+        padding_length = max_length - completion_ids.shape[0]
+        if padding_length > 0:
+            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
+            padded_ids = mx.concatenate([completion_ids, padding])
+            mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
+        else:
+            padded_ids = completion_ids
+            mask = mx.ones_like(completion_ids)
+        padded_completions.append(padded_ids)
+        attention_masks.append(mask)
+
+    inputs = mx.stack(padded_completions)
+    attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)

     # Get logits from current model
     logits = model(inputs).astype(mx.float32)

     # Calculate log probabilities
-    log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
+    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)

     # Prepare targets
     targets = inputs[:, 1:]
@@ -197,7 +294,7 @@ def grpo_loss(
     else:
         ref_logits = model(inputs).astype(mx.float32)

-    ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
+    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
     ref_token_log_probs = mx.take_along_axis(
         ref_log_probs,
         targets.reshape(*targets.shape, 1),
@@ -210,7 +307,11 @@ def grpo_loss(
     # Calculate combined rewards from all reward functions
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
-        func_rewards = mx.array(reward_func(all_completions))
+        func_rewards = mx.array(reward_func(
+            prompts=prompt_text,
+            completions=all_completion_texts,
+            answer=answer_text
+        ))
         rewards += func_rewards

     # Normalize rewards if using multiple reward functions
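
The normalization itself falls just outside the hunks shown here. For reference, this is the usual group-relative advantage GRPO uses, sketched with the same batch_size, group_size, and epsilon names; it describes the standard formulation, not necessarily the exact lines in this commit.

import mlx.core as mx

# Sketch of the standard GRPO advantage: rewards are standardized within each
# group of `group_size` completions that share the same prompt.
def group_relative_advantages(rewards, batch_size, group_size, epsilon=1e-4):
    grouped = rewards.reshape(batch_size, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.std(grouped, axis=1, keepdims=True)
    return ((grouped - mean) / (std + epsilon)).reshape(-1)
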
@@ -245,8 +346,12 @@ def grpo_loss(
     # Collect metrics for each reward function separately
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
-        func_rewards = mx.array(reward_func(all_completions))
-        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        func_rewards = mx.array(reward_func(
+            prompts=prompt_text,
+            completions=all_completion_texts,
+            answer=answer_text
+        ))
+        # func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
         reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)

@@ -264,26 +369,30 @@ def grpo_loss(

 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
-    Creates batches from prompt-answer pairs for GRPO training.
+    Creates batches from dataset entries for GRPO training.

     Args:
-        dataset: List of (prompt, answer) pairs
+        dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
         tokenizer: Tokenizer for processing inputs
         batch_size: Size of each batch
         max_seq_length: Maximum sequence length
         train: Whether this is for training

     Yields:
-        List of prompts for the current batch
+        Tuple containing:
+        - prompts_tokens: List of token sequences for current batch
+        - answers_tokens: List of token sequences
+        - prompts_text: List of prompt strings
+        - answers_text: List of answer strings
     """
-    # Verify dataset is not empty and has correct format
-    if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
-        raise ValueError("Dataset must be a list of (prompt, answer) pairs")
+    # Verify dataset format
+    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

-    # Sort by combined length of prompt + answer
+    # Sort by combined length of prompt + answer tokens
     idx = sorted(range(len(dataset)),
                  key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))

     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
@@ -306,22 +415,22 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
         indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))

         for i in indices:
-            # Get current batch of prompt-answer pairs
+            # Get current batch
             current_batch = [dataset[j] for j in batch_idx[i]]

-            # Extract prompts and answers
-            prompts = [pair[0] for pair in current_batch]
-            answers = [pair[1] for pair in current_batch]
-            if any(len(p) > max_seq_length for p in prompts):
+            # Extract all components
+            prompts_tokens = [item[0] for item in current_batch]
+            answers_tokens = [item[1] for item in current_batch]
+            prompts_text = [item[2] for item in current_batch]
+            answers_text = [item[3] for item in current_batch]
+
+            if any(len(p) > max_seq_length for p in prompts_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )

-            # For GRPO, we only need to yield the prompts
-            # The answers will be used by the reward functions
-            yield prompts
+            yield prompts_tokens, answers_tokens, prompts_text, answers_text

         if not train:
             break
@@ -342,11 +451,19 @@ def evaluate_grpo(
     loss: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
+    """
+    Evaluate model using GRPO loss.
+    Returns:
+        tuple: (average loss, number of tokens, average metrics)
+    """
     all_losses = 0
     ntokens = 0
+    all_metrics = None  # Initialize metrics dictionary

+    # Create iterator for batches
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

+    # Iterate through batches
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -356,35 +473,41 @@ def evaluate_grpo(
             max_seq_length=max_seq_length,
         ),
     ):
-        prompts = batch
+        # Calculate loss for current batch
         losses, toks, metrics = loss(
             model=model,
             tokenizer=tokenizer,
-            prompts=prompts,
+            batch=batch,
             reward_funcs=reward_funcs,
             beta=beta,
             group_size=group_size,
             epsilon=epsilon,
             ref_model=ref_model
         )

+        # Accumulate losses and tokens
         all_losses += losses * toks
         ntokens += toks

+        # Accumulate metrics
         if all_metrics is None:
             all_metrics = {k: v * toks for k, v in metrics.items()}
         else:
             for k, v in metrics.items():
                 all_metrics[k] += v * toks

+    # Evaluate accumulated values
     mx.eval(all_losses, ntokens)

+    # Aggregate across distributed workers
     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
     all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

+    # Calculate averages
     avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
     avg_loss = (all_losses / ntokens).item()

     return avg_loss, ntokens, avg_metrics

@@ -420,8 +543,18 @@ def train_grpo(
     state = [model.state, optimizer.state]

     def step(batch):

         # Forward and backward pass
-        (loss, toks, metrics), grad = loss_value_and_grad(model, *batch)
+        (loss, toks, metrics), grad = loss_value_and_grad(
+            model,
+            tokenizer=tokenizer,
+            batch=batch,
+            reward_funcs=reward_funcs,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            ref_model=ref_model
+        )
+
         # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)
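
A small sketch of why this call shape works, assuming the standard mlx.nn API: nn.value_and_grad(model, loss) returns a wrapper with the same signature as the loss, so extra inputs can be forwarded as keyword arguments, and the loss may return auxiliary values after the scalar being differentiated. The toy model and loss below are hypothetical stand-ins.

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)

def toy_loss(model, batch, scale=1.0):
    # First return value is the scalar being differentiated; the rest ride along.
    pred = model(batch)
    return scale * (pred ** 2).mean(), mx.array(batch.shape[0])

loss_value_and_grad = nn.value_and_grad(model, toy_loss)
(loss, toks), grads = loss_value_and_grad(model, batch=mx.ones((2, 4)), scale=0.5)
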
@@ -430,7 +563,7 @@ def train_grpo(
         optimizer.update(model, grad)

         return loss, toks, metrics

     loss_value_and_grad = nn.value_and_grad(model, loss)

     losses = 0