Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
separate functions
This commit is contained in: commit f13a0d04ca (parent d723ddfeda)
@@ -454,3 +454,176 @@ def main():
 if __name__ == "__main__":
     main()
 
+
+def compute_grpo_loss_and_grad(
+    model,
+    ref_model,
+    completion_tensors,
+    prompt_texts,
+    answer_texts,
+    beta=0.1,
+    epsilon=1e-4,
+    reward_funcs=None,
+    reward_weights=None
+):
+    """
+    Compute GRPO loss and gradients using pre-generated completions.
+
+    Args:
+        model: The policy model
+        ref_model: The reference model
+        completion_tensors: List of tensors containing generated completions
+        prompt_texts: List of prompt texts
+        answer_texts: List of answer texts
+        beta: KL penalty coefficient
+        epsilon: Numerical stability constant
+        reward_funcs: List of reward functions
+        reward_weights: Optional weights for reward functions
+    """
+    # Ensure model is in training mode for gradient computation
+    model.train()
+
+    # Get completion texts for reward calculation
+    completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]
+
+    # Prepare inputs for loss computation
+    max_length = max(tensor.shape[0] for tensor in completion_tensors)
+    padded_completions = []
+    attention_masks = []
+
+    for completion_tensor in completion_tensors:
+        padding_length = max_length - completion_tensor.shape[0]
+        if padding_length > 0:
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
+            mask = mx.concatenate(
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
+            )
+        else:
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
+        padded_completions.append(padded_ids)
+        attention_masks.append(mask)
+
+    inputs = mx.stack(padded_completions)
+    attention_mask = mx.stack(attention_masks)
+    lengths = attention_mask.sum(axis=1)
+
+    # Compute log probabilities for both models
+    token_log_probs = get_per_token_logps(model, inputs, lengths)
+
+    if ref_model is None:
+        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
+    else:
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]
+
+    # Pad log probabilities to same length
+    max_len = max(x.shape[0] for x in token_log_probs)
+    padded_log_probs = []
+    padded_ref_log_probs = []
+
+    for i in range(len(token_log_probs)):
+        seq_len = token_log_probs[i].shape[0]
+        padding = mx.zeros((max_len - seq_len,))
+
+        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
+
+    # Calculate rewards
+    all_func_rewards = []
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(
+            reward_func(
+                prompts=prompt_texts,
+                completions=completion_texts,
+                answer=answer_texts,
+            )
+        )
+        all_func_rewards.append(func_rewards)
+
+    # Stack rewards and apply weights
+    rewards = mx.stack(all_func_rewards, axis=1)
+    if reward_weights is not None:
+        if len(reward_weights) != len(reward_funcs):
+            raise ValueError(
+                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
+                f"functions ({len(reward_funcs)})"
+            )
+        reward_weights = mx.array(reward_weights, dtype=mx.float32)
+    else:
+        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
+
+    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
+
+    # Group rewards by prompt (assuming completions are grouped by prompt)
+    group_size = len(completion_tensors) // len(prompt_texts)
+    if len(completion_tensors) % len(prompt_texts) != 0:
+        raise ValueError("Number of completions must be divisible by number of prompts")
+
+    rewards_by_group = []
+    for i in range(0, len(rewards), group_size):
+        rewards_by_group.append(rewards[i : i + group_size])
+
+    # Calculate advantages
+    advantages = mx.zeros_like(rewards)
+    for i, group_rewards in enumerate(rewards_by_group):
+        if len(group_rewards) > 1:  # Only normalize if we have multiple samples
+            mean_reward = mx.mean(group_rewards)
+            std_reward = mx.std(group_rewards)
+
+            for j in range(group_size):
+                idx = i * group_size + j
+                advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
+        else:
+            # If only one sample, advantage is 0
+            advantages[i * group_size] = 0.0
+
+    # Compute KL divergence
+    kl_div = (
+        mx.exp(ref_token_log_probs - token_log_probs)
+        - (ref_token_log_probs - token_log_probs)
+        - 1
+    )
+
+    # Create mask for valid tokens
+    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
+
+    # Compute policy ratio
+    policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)
+
+    # Compute per-token loss
+    per_token_loss = -(
+        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    )
+
+    # Average over tokens
+    sequence_sums = per_token_loss.sum(axis=1)
+    sequence_lengths = length_mask.sum(axis=1)
+    loss = (sequence_sums / sequence_lengths).mean()
+
+    # Calculate metrics for reporting
+    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
+
+    metrics = {
+        "total_rewards_mean": mx.mean(rewards),
+        "total_rewards_std": mx.std(rewards),
+        "kl": mean_kl,
+    }
+
+    for i, reward_func in enumerate(reward_funcs):
+        func_name = reward_func.__name__
+        func_rewards = all_func_rewards[i]
+        metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
+        metrics[f"{func_name}_std"] = mx.std(func_rewards)
+
+    return loss, sequence_lengths.sum(), metrics
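
Note: despite its name, compute_grpo_loss_and_grad returns only the loss, a token count, and metrics, and it reads `tokenizer` from the enclosing module scope rather than taking it as a parameter. A minimal sketch of how it might be wired into a training step, assuming an mlx.optimizers optimizer, the file's existing `mx`/`nn` imports, a `reward_funcs` list, and completions/prompts/answers produced by a prior generation pass (illustrative only, not part of this commit):

    # Illustrative sketch; model, ref_model, optimizer, reward_funcs,
    # completion_tensors, prompt_texts, answer_texts are assumed to exist.
    loss_value_and_grad = nn.value_and_grad(model, compute_grpo_loss_and_grad)
    (loss, num_tokens, metrics), grads = loss_value_and_grad(
        model,
        ref_model,
        completion_tensors,
        prompt_texts,
        answer_texts,
        beta=0.1,
        reward_funcs=reward_funcs,
    )
    optimizer.update(model, grads)                      # apply the GRPO update
    mx.eval(model.parameters(), optimizer.state, loss)  # force evaluation of the lazy graph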
@@ -297,6 +297,79 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
     return per_token_logps
 
 
+def generate_without_gradients(
+    model: nn.Module,
+    tokenizer,
+    prompt_tokens,
+    max_tokens: int,
+    group_size: int,
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
+    """Generate completions without tracking gradients"""
+
+    # Store original state
+    was_training = model.training
+
+    # Force eval mode
+    model.eval()
+
+    # Prepare prompts
+    total_samples = len(prompt_tokens)
+    all_completions = []
+    all_completion_texts = []
+    batch_indices = []
+
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i : i + current_batch_size]
+
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+
+        # Convert to tensor and explicitly stop gradient
+        prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                group_size,
+                temperature=temperature,
+                batch_size=current_batch_size,
+            )
+
+            if completions is not None:
+                for j, completion_ids in enumerate(completions):
+                    prompt_idx = i + (j // group_size)
+
+                    if prompt_idx < total_samples:
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue
+
+    # Restore original state
+    if was_training:
+        model.train()
+
+    mx.metal.clear_cache()
+
+    return all_completions, all_completion_texts, batch_indices
+
+
 def grpo_loss(
     model,
     ref_model,
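
The helper above returns completions as a flat list, ordered group-by-group per prompt, which is the layout compute_grpo_loss_and_grad assumes. A hedged sketch of the calling contract, reusing the variable names grpo_loss uses below (illustrative only, not part of this commit):

    # Illustrative sketch; model, tokenizer, prompt_tokens, max_tokens,
    # group_size, temperature, batch_size are assumed to exist.
    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,   # list of token-id lists, one per prompt
        max_tokens=max_tokens,
        group_size=group_size,         # completions sampled per prompt
        temperature=temperature,
        batch_size=batch_size,
    )
    # When every batch generates successfully:
    #   len(all_completions) == len(prompt_tokens) * group_size
    #   batch_indices[k] == k // group_size  (completion k belongs to prompt k // group_size)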
@@ -313,70 +386,17 @@ def grpo_loss(
     is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
-    total_samples = len(prompt_tokens)
-
-    all_completions = []
-    all_completion_texts = []
-    batch_indices = []  # Keep track of which batch each completion belongs to
-
-    # Store original training state
-    was_training = model.training
-    print(f"Was model in training mode: {was_training}")
-
-    # Set model to eval mode for generation
-    model.eval()
-    print(f"Is model now in training mode: {model.training}")
-
-    # Process in smaller batches
-    for i in range(0, total_samples, batch_size):
-        # Get actual batch size for this iteration (might be smaller for the last batch)
-        current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i : i + current_batch_size]
-
-        # Pad sequences to the same length
-        max_prompt_len = max(len(p) for p in batch_prompts)
-        padded_prompts = []
-
-        for prompt in batch_prompts:
-            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
-            padded_prompts.append(prompt + padding)
-
-        # Convert to tensor
-        prompt_tensor = mx.array(padded_prompts)
-        prompt_tensor = mx.stop_gradient(prompt_tensor)  # Explicitly stop gradient on input
-
-        try:
-            mx.metal.clear_cache()
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                temperature=temperature,
-                batch_size=current_batch_size,
-            )
-
-            if completions is not None:
-                for j, completion_ids in enumerate(completions):
-                    # Calculate which prompt this completion belongs to
-                    prompt_idx = i + (j // group_size)
-
-                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
-                        batch_indices.append(prompt_idx)
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-        except Exception as e:
-            print(f"Generation error: {e}")
-            print(f"Is model in training mode after generation: {model.training}")
-            continue
-
-    # Restore original training state if we're not in validation mode
-    if was_training:
-        model.train()
-    mx.metal.clear_cache()
+    # Generate completions without tracking gradients
+    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
+        model=model,
+        tokenizer=tokenizer,
+        prompt_tokens=prompt_tokens,
+        max_tokens=max_tokens,
+        group_size=group_size,
+        temperature=temperature,
+        batch_size=batch_size
+    )
 
     # If we didn't generate any completions, return early
     if not all_completions:
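
For context on the group-relative advantage at the heart of GRPO (implemented in compute_grpo_loss_and_grad above): each completion's reward is normalized against the other completions sampled for the same prompt. A self-contained toy illustration with made-up rewards, not taken from the diff:

    import mlx.core as mx

    # Illustrative only: rewards for one prompt's group of completions
    group = mx.array([0.0, 1.0, 1.0, 0.0])
    advantages = (group - mx.mean(group)) / (mx.std(group) + 1e-4)
    # completions above the group mean get a positive advantage, the rest a negative one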
@@ -415,25 +435,30 @@ def grpo_loss(
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
 
-    # Continue with the rest of the function
+    # Create new input tensors for the model to compute logits with gradient tracking
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
 
     for completion_ids in all_completions:
-        padding_length = max_length - completion_ids.shape[0]
+        # Convert the pre-generated completion to a regular tensor (not stop_gradient)
+        # This allows gradients to flow during the loss computation phase
+        completion_tensor = mx.array(completion_ids.tolist())
+
+        padding_length = max_length - completion_tensor.shape[0]
         if padding_length > 0:
-            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
-            padded_ids = mx.concatenate([completion_ids, padding])
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
             mask = mx.concatenate(
-                [mx.ones_like(completion_ids), mx.zeros_like(padding)]
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
             )
         else:
-            padded_ids = completion_ids
-            mask = mx.ones_like(completion_ids)
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
 
     # Rest of the function remains the same
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
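
The rewrite in the hunk above rebuilds each pre-generated completion as a fresh array before padding, so the stacked inputs carry no history from the generation pass. A tiny self-contained illustration of that conversion (toy token ids, not from the diff):

    import mlx.core as mx

    # Illustrative only
    generated = mx.stop_gradient(mx.array([101, 7, 42]))  # ids produced during generation
    fresh = mx.array(generated.tolist())                   # new array, detached from the generation graph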
@@ -721,7 +746,6 @@ def evaluate_grpo(
         ref_model=ref_model,
         temperature=temperature,
         max_tokens=max_tokens,
         is_validation=True
     )
 
     all_losses += losses * toks