Mirror of https://github.com/ml-explore/mlx-examples.git

Commit 05d921b788 (parent 1d9e4802f0): optims
@@ -35,68 +35,50 @@ class GRPOTrainingArgs(TrainingArgs):
     )


-def generate_for_grpo(
-    model,
-    prompt,
-    max_tokens,
-    tokenizer,
-    temperature=1.0
-):
-    try:
-        # Ensure prompt is the right shape
-        if len(prompt.shape) == 1:
-            prompt = prompt[None, :]
-
-        # Initialize generation
-        generated = []
-        current_prompt = prompt[0]
-
-        for step in range(max_tokens):
-            try:
-                # Get model output with explicit shape checking
-                current_batch = current_prompt[None, :]
-                logits = model(current_batch)
-
-                # Ensure we have the last token logits
-                token_logits = logits[0, -1]
-
-                # Apply temperature and get probabilities
-                if temperature > 0:
-                    token_logits = token_logits / temperature
-                probs = mx.softmax(token_logits)
-
-                # Sample the next token
-                next_token = mx.random.categorical(probs[None, :])
-                next_token = next_token[0]
-
-                # Force evaluation to catch any issues
-                mx.eval(next_token)
-                token_value = next_token.item()
-
-                # Add to generated sequence
-                generated.append(next_token)
-                current_prompt = mx.concatenate([current_prompt, next_token[None]])
-
-                if token_value == tokenizer.eos_token_id:
-                    break
-            except Exception as e:
-                raise
-
-        if not generated:
-            return prompt[0]
-
-        try:
-            result = mx.concatenate([prompt[0], mx.stack(generated)])
-            mx.eval(result)
-            return result
-        except Exception as e:
-            raise
-
-    except Exception as e:
-        raise
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+    model.eval()
+    if len(prompt.shape) == 1:
+        prompt = prompt[None, :]
+
+    generated = []
+    current_prompt = prompt[0]
+
+    for _ in range(max_tokens):
+        current_batch = current_prompt[None, :]
+        logits = model(current_batch)
+        token_logits = logits[0, -1]
+        if temperature > 0:
+            token_logits = token_logits / temperature
+
+        probs = mx.softmax(token_logits)
+        next_token = mx.random.categorical(probs[None, :])
+        next_token = next_token[0]
+        mx.eval(next_token)
+        token_value = next_token.item()
+        generated.append(next_token)
+
+        # Clear intermediate tensors
+        del logits, token_logits, probs
+        mx.metal.clear_cache()
+
+        current_prompt = mx.concatenate([current_prompt, next_token[None]])
+        if token_value == tokenizer.eos_token_id:
+            break
+
+    if not generated:
+        return prompt[0]
+
+    result = mx.concatenate([prompt[0], mx.stack(generated)])
+    mx.eval(result)
+    model.train()
+
+    # Clear generated tokens
+    del generated
+    mx.metal.clear_cache()
+
+    return result


 def r1_extract_xml_answer(text: str) -> str:
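For orientation, a minimal usage sketch of the new generate_grpo helper follows; the checkpoint name, prompt text, and the use of mlx_lm.load are illustrative assumptions, not part of this commit.

import mlx.core as mx
from mlx_lm import load

# Assumed checkpoint; any MLX causal LM plus its tokenizer should work the same way.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

# generate_grpo is the function defined in the diff above (assumed to be in scope).
prompt = mx.array(tokenizer.encode("What is 2 + 2?"))
completion_ids = generate_grpo(model, prompt, max_tokens=64, tokenizer=tokenizer, temperature=1.0)
print(tokenizer.decode(completion_ids.tolist()))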
@@ -191,67 +173,46 @@ def grpo_loss(
     group_size=4,
     epsilon=1e-4,
     ref_model=None,
-    max_tokens=128,
+    max_tokens=64,
     temperature=1.0
 ):
-    """Modified GRPO loss function with better error handling"""
     prompt_tokens, answer_tokens, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)

-    # Generate completions for each prompt
+    # Generation logic remains the same
     all_completions = []
    all_completion_texts = []

     for prompt in prompt_tokens:
         prompt_tensor = mx.array(prompt)
-        prompt_completions = []
-        prompt_completion_texts = []

-        # Generate group_size completions for each prompt
         for _ in range(group_size):
             try:
-                completion_ids = generate_for_grpo(
-                    model,
-                    prompt_tensor,
-                    max_tokens,
-                    tokenizer=tokenizer,
-                    temperature=temperature
-                )
-
-                # Verify completion_ids is not None
+                completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
+
                 if completion_ids is None:
-                    print("Warning: generate_for_grpo returned None")
-                    break
+                    continue

                 completion_text = tokenizer.decode(completion_ids.tolist())
+                all_completions.append(completion_ids)
+                all_completion_texts.append(completion_text)

-                prompt_completions.append(completion_ids)
-                prompt_completion_texts.append(completion_text)
+                del completion_ids
+                mx.metal.clear_cache()

             except Exception as e:
-                print(f"Error in completion generation: {str(e)}")
-                # Fallback to using original prompt
-                prompt_completions.append(prompt_tensor)
-                prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
+                print(f"Generation error: {e}")
+                continue

-        all_completions.extend(prompt_completions)
-        all_completion_texts.extend(prompt_completion_texts)
+        del prompt_tensor
+        mx.metal.clear_cache()

-    # Verify we have the expected number of completions
-    assert len(all_completions) == batch_size * group_size
-    assert len(all_completion_texts) == batch_size * group_size
-
-    # Expand answer_text and prompt_text to match completion groups
+    # Prepare inputs
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
         expanded_answers.extend([answer_text[i]] * group_size)
         expanded_prompts.extend([prompt_text[i]] * group_size)

-    # Verify we have the expected number of completions
-    assert len(all_completions) == batch_size * group_size
-    assert len(all_completion_texts) == batch_size * group_size
-
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
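As a sanity check on the ordering the loop above relies on, here is a toy, self-contained illustration (made-up values, not from the diff): completions are appended group by group, so each prompt's text and answer must be repeated group_size times to stay aligned with them.

# batch_size=2 prompts, group_size=3 completions per prompt.
prompt_text = ["p0", "p1"]
answer_text = ["a0", "a1"]
group_size = 3

expanded_prompts, expanded_answers = [], []
for i in range(len(prompt_text)):
    expanded_prompts.extend([prompt_text[i]] * group_size)
    expanded_answers.extend([answer_text[i]] * group_size)

# expanded_prompts == ['p0', 'p0', 'p0', 'p1', 'p1', 'p1'], matching completions
# ordered as [p0_c0, p0_c1, p0_c2, p1_c0, p1_c1, p1_c2].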
@@ -268,31 +229,36 @@ def grpo_loss(
         padded_completions.append(padded_ids)
         attention_masks.append(mask)

+        del completion_ids
+        if padding_length > 0:
+            del padding
+        del mask
+        mx.metal.clear_cache()
+
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)

-    # Get logits from current model
+    del padded_completions, attention_masks
+    mx.metal.clear_cache()
+
+    # Get logits and compute log probabilities
     logits = model(inputs).astype(mx.float32)

-    # Calculate log probabilities
     log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)

-    # Prepare targets
     targets = inputs[:, 1:]

-    # Gather actual token probabilities
+    # Current policy probabilities
     token_log_probs = mx.take_along_axis(
         log_probs,
         targets.reshape(*targets.shape, 1),
         axis=-1
     ).squeeze(-1)

-    # Get reference model log probabilities
+    # Reference policy probabilities
     if ref_model is not None:
         ref_logits = ref_model(inputs).astype(mx.float32)
     else:
-        ref_logits = model(inputs).astype(mx.float32)
+        ref_logits = mx.array(logits)

     ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
     ref_token_log_probs = mx.take_along_axis(
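A standalone sketch of the gather used above, with small assumed shapes, to make the shifted-target indexing explicit (toy tensors, not the trainer's real inputs):

import mlx.core as mx
import mlx.nn as nn

B, T, V = 2, 5, 11                        # toy batch size, sequence length, vocab size
inputs = mx.random.randint(0, V, (B, T))  # stands in for the stacked padded completions
logits = mx.random.normal((B, T, V))      # stands in for model(inputs)

log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)   # positions 0..T-2 predict tokens 1..T-1
targets = inputs[:, 1:]                                   # next-token targets, shape (B, T-1)
token_log_probs = mx.take_along_axis(
    log_probs, targets.reshape(*targets.shape, 1), axis=-1
).squeeze(-1)                                             # per-token log-probs, shape (B, T-1)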
@@ -301,10 +267,7 @@ def grpo_loss(
         axis=-1
     ).squeeze(-1)

-    # Compute KL divergence
-    kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
-
-    # Calculate combined rewards from all reward functions
+    # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
@@ -314,36 +277,36 @@ def grpo_loss(
         ))
         rewards += func_rewards

-    # Normalize rewards if using multiple reward functions
     if len(reward_funcs) > 1:
         rewards /= len(reward_funcs)

-    # Compute grouped-wise rewards
-    grouped_rewards = rewards.reshape(batch_size, group_size)
-    mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
-    std_grouped_rewards = mx.std(grouped_rewards, axis=1)
-
-    # Normalize rewards to compute advantages
-    mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
+    # Reshape rewards and compute advantages following GRPO formula
+    rewards_reshaped = rewards.reshape(batch_size, group_size)
+    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
+    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
+    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
+
+    # Compute KL divergence using Schulman's approximator
+    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1

-    # Create length mask for the shifted sequence
+    # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

-    # Calculate policy gradient loss
-    per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
-    per_token_loss = -(per_token_loss - beta * kl_div)
-
-    # Normalize loss properly per sequence
+    # Compute policy ratio
+    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
+
+    # Compute per-token loss following GRPO formula
+    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
+
+    # Average over tokens and sequences
     sequence_sums = (per_token_loss * length_mask).sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()

-    # Calculate mean KL divergence
+    # Calculate mean KL divergence for metrics
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()

-    # Collect metrics for each reward function separately
+    # Collect reward metrics
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
         func_rewards = mx.array(reward_func(
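The two formulas introduced in this hunk, restated as a self-contained numeric sketch (reward and log-probability values are made up): the group-relative advantage (r - mean_group) / (std_group + epsilon), and the Schulman k3 KL approximator exp(d) - d - 1 with d = log p_ref - log p_current, which is non-negative and zero when the two policies agree.

import mlx.core as mx

batch_size, group_size, epsilon = 2, 3, 1e-4
rewards = mx.array([1.0, 0.0, 2.0, 0.5, 0.5, 1.5])    # batch_size * group_size toy rewards

grouped = rewards.reshape(batch_size, group_size)
mean_rewards = mx.broadcast_to(mx.mean(grouped, axis=1)[:, None], grouped.shape).reshape(-1)
std_rewards = mx.broadcast_to(mx.std(grouped, axis=1)[:, None], grouped.shape).reshape(-1)
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)   # above group average -> positive

d = mx.array([0.1, -0.2, 0.0])          # toy values of (log p_ref - log p_current)
kl_approx = mx.exp(d) - d - 1           # elementwise k3 estimate, always >= 0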
@@ -351,15 +314,18 @@ def grpo_loss(
             completions=all_completion_texts,
             answer=answer_text
         ))
-        # func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
         reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)

+    # Clean up
+    del all_completions
+    mx.metal.clear_cache()
+
     metrics = {
         'total_rewards_mean': mx.mean(rewards),
         'total_rewards_std': mx.std(rewards),
-        'grouped_rewards_mean': mx.mean(grouped_rewards),
-        'grouped_rewards_std': mx.std(grouped_rewards),
+        'grouped_rewards_mean': mx.mean(rewards_reshaped),
+        'grouped_rewards_std': mx.std(rewards_reshaped),
         'kl': mean_kl,
         **reward_metrics
     }
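For context, a toy sketch of how several reward functions are combined into the single rewards vector used above; the two reward functions here are illustrative stand-ins, not the repository's real ones.

import mlx.core as mx

def reward_length(prompts, completions, answer):
    # Stand-in: longer completions score higher.
    return [float(len(c)) for c in completions]

def reward_match(prompts, completions, answer):
    # Stand-in: 1.0 if the reference answer appears in the completion.
    return [1.0 if a in c else 0.0 for c, a in zip(completions, answer)]

reward_funcs = [reward_length, reward_match]
completions = ["the answer is 4", "no idea"]
answers = ["4", "4"]

rewards = mx.zeros((len(completions),))
for reward_func in reward_funcs:
    rewards += mx.array(reward_func(prompts=None, completions=completions, answer=answers))
if len(reward_funcs) > 1:
    rewards /= len(reward_funcs)   # simple mean over reward functions, as in the loss above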
@@ -368,30 +334,15 @@ def grpo_loss(


 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Creates batches from dataset entries for GRPO training.
-
-    Args:
-        dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
-        tokenizer: Tokenizer for processing inputs
-        batch_size: Size of each batch
-        max_seq_length: Maximum sequence length
-        train: Whether this is for training
-
-    Yields:
-        Tuple containing:
-        - prompts_tokens: List of token sequences for current batch
-        - answers_tokens: List of token sequences
-        - prompts_text: List of prompt strings
-        - answers_text: List of answer strings
-    """
-    # Verify dataset format
+    """Memory-optimized version of iterate_grpo_batches"""
     if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
         raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

-    # Sort by combined length of prompt + answer tokens
-    idx = sorted(range(len(dataset)),
-                 key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
+    # Sort by length but use generator to avoid keeping full sorted list in memory
+    def length_key(i):
+        return len(dataset[i][0]) + len(dataset[i][1])
+
+    idx = sorted(range(len(dataset)), key=length_key)

     if len(dataset) < batch_size:
         raise ValueError(
@@ -399,26 +350,24 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
             f"examples but only has {len(dataset)}."
         )

-    # Handle distributed training
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")

-    # Create batch indices
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
+    # Use generator for batch indices
+    def batch_index_generator():
+        for i in range(0, len(idx) - batch_size + 1, batch_size):
+            yield idx[i : i + batch_size : step]

     while True:
-        # Shuffle batch indices if training
-        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
+        indices = (
+            np.random.permutation(list(batch_index_generator())) if train
+            else batch_index_generator()
+        )

-        for i in indices:
-            # Get current batch
-            current_batch = [dataset[j] for j in batch_idx[i]]
+        for batch_idx in indices:
+            current_batch = [dataset[j] for j in batch_idx]

-            # Extract all components
             prompts_tokens = [item[0] for item in current_batch]
             answers_tokens = [item[1] for item in current_batch]
             prompts_text = [item[2] for item in current_batch]
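A small standalone sketch of the index slicing used above (toy numbers): each window of batch_size length-sorted indices is subsampled with stride step, where step is the distributed world size, so each yielded batch contains batch_size // step of the window's indices.

# Toy illustration: 8 examples, batch_size=4, step=2 workers.
idx = list(range(8))          # stands in for the length-sorted dataset indices
batch_size, step = 4, 2

batches = [
    idx[i : i + batch_size : step]
    for i in range(0, len(idx) - batch_size + 1, batch_size)
]
# batches == [[0, 2], [4, 6]]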
@@ -553,7 +502,8 @@ def train_grpo(
         beta=args.beta,
         group_size=args.group_size,
         epsilon=args.epsilon,
-        ref_model=ref_model
+        ref_model=ref_model,
+        max_tokens=args.max_seq_length,
     )

     # All reduce the gradients if running in distributed mode
@@ -649,8 +599,10 @@ def train_grpo(
         losses += loss
         n_tokens += toks
         steps += 1

         for k, v in metrics.items():
             accumulated_metrics[k] += v

         mx.eval(state, losses, n_tokens)

         if it % args.steps_per_report == 0 or it == args.iters: