mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-15 23:11:12 +08:00
optims
This commit is contained in:
parent
1d9e4802f0
commit
05d921b788
@ -35,68 +35,50 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
)
|
||||
|
||||
|
||||
def generate_for_grpo(
|
||||
model,
|
||||
prompt,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
temperature=1.0
|
||||
):
|
||||
try:
|
||||
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
|
||||
model.eval()
|
||||
if len(prompt.shape) == 1:
|
||||
prompt = prompt[None, :]
|
||||
|
||||
generated = []
|
||||
current_prompt = prompt[0]
|
||||
|
||||
for _ in range(max_tokens):
|
||||
current_batch = current_prompt[None, :]
|
||||
logits = model(current_batch)
|
||||
token_logits = logits[0, -1]
|
||||
|
||||
# Ensure prompt is the right shape
|
||||
if len(prompt.shape) == 1:
|
||||
prompt = prompt[None, :]
|
||||
|
||||
# Initialize generation
|
||||
generated = []
|
||||
current_prompt = prompt[0]
|
||||
|
||||
for step in range(max_tokens):
|
||||
try:
|
||||
# Get model output with explicit shape checking
|
||||
current_batch = current_prompt[None, :]
|
||||
|
||||
logits = model(current_batch)
|
||||
|
||||
# Ensure we have the last token logits
|
||||
token_logits = logits[0, -1]
|
||||
|
||||
# Apply temperature and get probabilities
|
||||
if temperature > 0:
|
||||
token_logits = token_logits / temperature
|
||||
probs = mx.softmax(token_logits)
|
||||
|
||||
# Sample the next token
|
||||
next_token = mx.random.categorical(probs[None, :])
|
||||
next_token = next_token[0]
|
||||
|
||||
# Force evaluation to catch any issues
|
||||
mx.eval(next_token)
|
||||
token_value = next_token.item()
|
||||
|
||||
# Add to generated sequence
|
||||
generated.append(next_token)
|
||||
current_prompt = mx.concatenate([current_prompt, next_token[None]])
|
||||
|
||||
if token_value == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
if not generated:
|
||||
return prompt[0]
|
||||
if temperature > 0:
|
||||
token_logits = token_logits / temperature
|
||||
|
||||
try:
|
||||
result = mx.concatenate([prompt[0], mx.stack(generated)])
|
||||
mx.eval(result)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise
|
||||
probs = mx.softmax(token_logits)
|
||||
next_token = mx.random.categorical(probs[None, :])
|
||||
next_token = next_token[0]
|
||||
mx.eval(next_token)
|
||||
|
||||
token_value = next_token.item()
|
||||
generated.append(next_token)
|
||||
|
||||
# Clear intermediate tensors
|
||||
del logits, token_logits, probs
|
||||
mx.metal.clear_cache()
|
||||
|
||||
current_prompt = mx.concatenate([current_prompt, next_token[None]])
|
||||
if token_value == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
raise
|
||||
if not generated:
|
||||
return prompt[0]
|
||||
|
||||
result = mx.concatenate([prompt[0], mx.stack(generated)])
|
||||
mx.eval(result)
|
||||
model.train()
|
||||
|
||||
# Clear generated tokens
|
||||
del generated
|
||||
mx.metal.clear_cache()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def r1_extract_xml_answer(text: str) -> str:
|
||||
@ -191,67 +173,46 @@ def grpo_loss(
|
||||
group_size=4,
|
||||
epsilon=1e-4,
|
||||
ref_model=None,
|
||||
max_tokens=128,
|
||||
max_tokens=64,
|
||||
temperature=1.0
|
||||
):
|
||||
"""Modified GRPO loss function with better error handling"""
|
||||
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
|
||||
batch_size = len(prompt_tokens)
|
||||
|
||||
# Generate completions for each prompt
|
||||
# Generation logic remains the same
|
||||
all_completions = []
|
||||
all_completion_texts = []
|
||||
|
||||
for prompt in prompt_tokens:
|
||||
prompt_tensor = mx.array(prompt)
|
||||
prompt_completions = []
|
||||
prompt_completion_texts = []
|
||||
|
||||
# Generate group_size completions for each prompt
|
||||
for _ in range(group_size):
|
||||
try:
|
||||
completion_ids = generate_for_grpo(
|
||||
model,
|
||||
prompt_tensor,
|
||||
max_tokens,
|
||||
tokenizer=tokenizer,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
# Verify completion_ids is not None
|
||||
completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
|
||||
if completion_ids is None:
|
||||
print("Warning: generate_for_grpo returned None")
|
||||
break
|
||||
continue
|
||||
|
||||
completion_text = tokenizer.decode(completion_ids.tolist())
|
||||
all_completions.append(completion_ids)
|
||||
all_completion_texts.append(completion_text)
|
||||
|
||||
prompt_completions.append(completion_ids)
|
||||
prompt_completion_texts.append(completion_text)
|
||||
del completion_ids
|
||||
mx.metal.clear_cache()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in completion generation: {str(e)}")
|
||||
# Fallback to using original prompt
|
||||
prompt_completions.append(prompt_tensor)
|
||||
prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
|
||||
print(f"Generation error: {e}")
|
||||
continue
|
||||
|
||||
all_completions.extend(prompt_completions)
|
||||
all_completion_texts.extend(prompt_completion_texts)
|
||||
del prompt_tensor
|
||||
mx.metal.clear_cache()
|
||||
|
||||
# Verify we have the expected number of completions
|
||||
assert len(all_completions) == batch_size * group_size
|
||||
assert len(all_completion_texts) == batch_size * group_size
|
||||
|
||||
# Expand answer_text and prompt_text to match completion groups
|
||||
# Prepare inputs
|
||||
expanded_answers = []
|
||||
expanded_prompts = []
|
||||
for i in range(batch_size):
|
||||
expanded_answers.extend([answer_text[i]] * group_size)
|
||||
expanded_prompts.extend([prompt_text[i]] * group_size)
|
||||
|
||||
# Verify we have the expected number of completions
|
||||
assert len(all_completions) == batch_size * group_size
|
||||
assert len(all_completion_texts) == batch_size * group_size
|
||||
|
||||
|
||||
max_length = max(ids.shape[0] for ids in all_completions)
|
||||
padded_completions = []
|
||||
attention_masks = []
|
||||
@ -267,32 +228,37 @@ def grpo_loss(
|
||||
mask = mx.ones_like(completion_ids)
|
||||
padded_completions.append(padded_ids)
|
||||
attention_masks.append(mask)
|
||||
|
||||
del completion_ids
|
||||
if padding_length > 0:
|
||||
del padding
|
||||
del mask
|
||||
mx.metal.clear_cache()
|
||||
|
||||
inputs = mx.stack(padded_completions)
|
||||
attention_mask = mx.stack(attention_masks)
|
||||
lengths = attention_mask.sum(axis=1)
|
||||
|
||||
# Get logits from current model
|
||||
del padded_completions, attention_masks
|
||||
mx.metal.clear_cache()
|
||||
|
||||
# Get logits and compute log probabilities
|
||||
logits = model(inputs).astype(mx.float32)
|
||||
|
||||
# Calculate log probabilities
|
||||
log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
|
||||
|
||||
# Prepare targets
|
||||
targets = inputs[:, 1:]
|
||||
|
||||
# Gather actual token probabilities
|
||||
# Current policy probabilities
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
targets.reshape(*targets.shape, 1),
|
||||
axis=-1
|
||||
).squeeze(-1)
|
||||
|
||||
# Get reference model log probabilities
|
||||
# Reference policy probabilities
|
||||
if ref_model is not None:
|
||||
ref_logits = ref_model(inputs).astype(mx.float32)
|
||||
else:
|
||||
ref_logits = model(inputs).astype(mx.float32)
|
||||
ref_logits = mx.array(logits)
|
||||
|
||||
ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
|
||||
ref_token_log_probs = mx.take_along_axis(
|
||||
@ -301,124 +267,107 @@ def grpo_loss(
|
||||
axis=-1
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute KL divergence
|
||||
kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
|
||||
|
||||
# Calculate combined rewards from all reward functions
|
||||
# Calculate rewards and advantages
|
||||
rewards = mx.zeros((len(all_completions),))
|
||||
for reward_func in reward_funcs:
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=prompt_text,
|
||||
completions=all_completion_texts,
|
||||
answer=answer_text
|
||||
))
|
||||
rewards += func_rewards
|
||||
|
||||
# Normalize rewards if using multiple reward functions
|
||||
if len(reward_funcs) > 1:
|
||||
rewards /= len(reward_funcs)
|
||||
|
||||
# Compute grouped-wise rewards
|
||||
grouped_rewards = rewards.reshape(batch_size, group_size)
|
||||
mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
|
||||
std_grouped_rewards = mx.std(grouped_rewards, axis=1)
|
||||
|
||||
# Normalize rewards to compute advantages
|
||||
mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
|
||||
std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
|
||||
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
|
||||
|
||||
# Create length mask for the shifted sequence
|
||||
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
|
||||
|
||||
# Calculate policy gradient loss
|
||||
per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
|
||||
per_token_loss = -(per_token_loss - beta * kl_div)
|
||||
|
||||
# Normalize loss properly per sequence
|
||||
sequence_sums = (per_token_loss * length_mask).sum(axis=1)
|
||||
sequence_lengths = length_mask.sum(axis=1)
|
||||
loss = (sequence_sums / sequence_lengths).mean()
|
||||
|
||||
# Calculate mean KL divergence
|
||||
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
|
||||
|
||||
# Collect metrics for each reward function separately
|
||||
reward_metrics = {}
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=prompt_text,
|
||||
prompts=prompt_text,
|
||||
completions=all_completion_texts,
|
||||
answer=answer_text
|
||||
))
|
||||
rewards += func_rewards
|
||||
|
||||
if len(reward_funcs) > 1:
|
||||
rewards /= len(reward_funcs)
|
||||
|
||||
# Reshape rewards and compute advantages following GRPO formula
|
||||
rewards_reshaped = rewards.reshape(batch_size, group_size)
|
||||
mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
|
||||
std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
|
||||
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
|
||||
|
||||
# Compute KL divergence using Schulman's approximator
|
||||
kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
|
||||
|
||||
# Create mask for valid tokens
|
||||
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
|
||||
|
||||
# Compute policy ratio
|
||||
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
|
||||
|
||||
# Compute per-token loss following GRPO formula
|
||||
per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
|
||||
|
||||
# Average over tokens and sequences
|
||||
sequence_sums = (per_token_loss * length_mask).sum(axis=1)
|
||||
sequence_lengths = length_mask.sum(axis=1)
|
||||
loss = (sequence_sums / sequence_lengths).mean()
|
||||
|
||||
# Calculate mean KL divergence for metrics
|
||||
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
|
||||
|
||||
# Collect reward metrics
|
||||
reward_metrics = {}
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=prompt_text,
|
||||
completions=all_completion_texts,
|
||||
answer=answer_text
|
||||
))
|
||||
# func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
|
||||
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
|
||||
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
|
||||
|
||||
|
||||
# Clean up
|
||||
del all_completions
|
||||
mx.metal.clear_cache()
|
||||
|
||||
metrics = {
|
||||
'total_rewards_mean': mx.mean(rewards),
|
||||
'total_rewards_std': mx.std(rewards),
|
||||
'grouped_rewards_mean': mx.mean(grouped_rewards),
|
||||
'grouped_rewards_std': mx.std(grouped_rewards),
|
||||
'grouped_rewards_mean': mx.mean(rewards_reshaped),
|
||||
'grouped_rewards_std': mx.std(rewards_reshaped),
|
||||
'kl': mean_kl,
|
||||
**reward_metrics
|
||||
}
|
||||
|
||||
|
||||
return loss, sequence_lengths.sum(), metrics
|
||||
|
||||
|
||||
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
"""
|
||||
Creates batches from dataset entries for GRPO training.
|
||||
|
||||
Args:
|
||||
dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
|
||||
tokenizer: Tokenizer for processing inputs
|
||||
batch_size: Size of each batch
|
||||
max_seq_length: Maximum sequence length
|
||||
train: Whether this is for training
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- prompts_tokens: List of token sequences for current batch
|
||||
- answers_tokens: List of token sequences
|
||||
- prompts_text: List of prompt strings
|
||||
- answers_text: List of answer strings
|
||||
"""
|
||||
# Verify dataset format
|
||||
"""Memory-optimized version of iterate_grpo_batches"""
|
||||
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
|
||||
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
|
||||
|
||||
# Sort by combined length of prompt + answer tokens
|
||||
idx = sorted(range(len(dataset)),
|
||||
key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
|
||||
|
||||
# Sort by length but use generator to avoid keeping full sorted list in memory
|
||||
def length_key(i):
|
||||
return len(dataset[i][0]) + len(dataset[i][1])
|
||||
|
||||
idx = sorted(range(len(dataset)), key=length_key)
|
||||
|
||||
if len(dataset) < batch_size:
|
||||
raise ValueError(
|
||||
f"Dataset must have at least batch_size={batch_size} "
|
||||
f"examples but only has {len(dataset)}."
|
||||
)
|
||||
|
||||
# Handle distributed training
|
||||
step = mx.distributed.init().size()
|
||||
if batch_size % step != 0:
|
||||
raise ValueError("The batch size must be divisible by the number of workers")
|
||||
|
||||
# Create batch indices
|
||||
batch_idx = [
|
||||
idx[i : i + batch_size : step]
|
||||
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
]
|
||||
# Use generator for batch indices
|
||||
def batch_index_generator():
|
||||
for i in range(0, len(idx) - batch_size + 1, batch_size):
|
||||
yield idx[i : i + batch_size : step]
|
||||
|
||||
while True:
|
||||
# Shuffle batch indices if training
|
||||
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
|
||||
indices = (
|
||||
np.random.permutation(list(batch_index_generator())) if train
|
||||
else batch_index_generator()
|
||||
)
|
||||
|
||||
for i in indices:
|
||||
# Get current batch
|
||||
current_batch = [dataset[j] for j in batch_idx[i]]
|
||||
for batch_idx in indices:
|
||||
current_batch = [dataset[j] for j in batch_idx]
|
||||
|
||||
# Extract all components
|
||||
prompts_tokens = [item[0] for item in current_batch]
|
||||
answers_tokens = [item[1] for item in current_batch]
|
||||
prompts_text = [item[2] for item in current_batch]
|
||||
@ -553,7 +502,8 @@ def train_grpo(
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
epsilon=args.epsilon,
|
||||
ref_model=ref_model
|
||||
ref_model=ref_model,
|
||||
max_tokens=args.max_seq_length,
|
||||
)
|
||||
|
||||
# All reduce the gradients if running in distributed mode
|
||||
@ -649,8 +599,10 @@ def train_grpo(
|
||||
losses += loss
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
|
||||
for k, v in metrics.items():
|
||||
accumulated_metrics[k] += v
|
||||
|
||||
mx.eval(state, losses, n_tokens)
|
||||
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
|
Loading…
Reference in New Issue
Block a user