Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Separate functions
@@ -297,6 +297,79 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
     return per_token_logps


+def generate_without_gradients(
+    model: nn.Module,
+    tokenizer,
+    prompt_tokens,
+    max_tokens: int,
+    group_size: int,
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
+    """Generate completions without tracking gradients"""
+
+    # Store original state
+    was_training = model.training
+
+    # Force eval mode
+    model.eval()
+
+    # Prepare prompts
+    total_samples = len(prompt_tokens)
+    all_completions = []
+    all_completion_texts = []
+    batch_indices = []
+
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i : i + current_batch_size]
+
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+
+        # Convert to tensor and explicitly stop gradient
+        prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                group_size,
+                temperature=temperature,
+                batch_size=current_batch_size,
+            )
+
+            if completions is not None:
+                for j, completion_ids in enumerate(completions):
+                    prompt_idx = i + (j // group_size)
+
+                    if prompt_idx < total_samples:
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue
+
+    # Restore original state
+    if was_training:
+        model.train()
+
+    mx.metal.clear_cache()
+
+    return all_completions, all_completion_texts, batch_indices
+
+
 def grpo_loss(
     model,
     ref_model,
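For orientation, here is a minimal sketch of how the new helper might be called on its own, outside of grpo_loss. Only generate_without_gradients and its arguments come from the diff above; the model/tokenizer loading via mlx_lm.load, the example model name, and the toy prompts are assumptions for illustration, and generate_grpo is assumed to be defined in the same training module.

# Hypothetical standalone usage of generate_without_gradients (sketch, not part of the commit).
from mlx_lm import load

model, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")  # assumed example model

# Tokenized prompts, one list of token ids per prompt (illustrative only)
prompt_tokens = [
    tokenizer.encode("What is 2 + 2?"),
    tokenizer.encode("Name a prime number greater than 10."),
]

completions, completion_texts, batch_indices = generate_without_gradients(
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=prompt_tokens,
    max_tokens=64,
    group_size=4,      # 4 sampled completions per prompt
    temperature=0.8,
    batch_size=1,
)

# batch_indices[k] records which prompt completion k belongs to
for text, idx in zip(completion_texts, batch_indices):
    print(idx, text[:60])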
@@ -313,70 +386,17 @@ def grpo_loss(
     is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
-    total_samples = len(prompt_tokens)
-
-    all_completions = []
-    all_completion_texts = []
-    batch_indices = []  # Keep track of which batch each completion belongs to
-
-    # Store original training state
-    was_training = model.training
-    print(f"Was model in training mode: {was_training}")
-
-    # Set model to eval mode for generation
-    model.eval()
-    print(f"Is model now in training mode: {model.training}")
-
-    # Process in smaller batches
-    for i in range(0, total_samples, batch_size):
-        # Get actual batch size for this iteration (might be smaller for the last batch)
-        current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i : i + current_batch_size]
-
-        # Pad sequences to the same length
-        max_prompt_len = max(len(p) for p in batch_prompts)
-        padded_prompts = []
-
-        for prompt in batch_prompts:
-            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
-            padded_prompts.append(prompt + padding)
-
-        # Convert to tensor
-        prompt_tensor = mx.array(padded_prompts)
-        prompt_tensor = mx.stop_gradient(prompt_tensor)  # Explicitly stop gradient on input
-
-        try:
-            mx.metal.clear_cache()
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                temperature=temperature,
-                batch_size=current_batch_size,
-            )
-
-            if completions is not None:
-                for j, completion_ids in enumerate(completions):
-                    # Calculate which prompt this completion belongs to
-                    prompt_idx = i + (j // group_size)
-
-                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
-                        batch_indices.append(prompt_idx)
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-        except Exception as e:
-            print(f"Generation error: {e}")
-            print(f"Is model in training mode after generation: {model.training}")
-            continue
-
-    # Restore original training state if we're not in validation mode
-    if was_training:
-        model.train()
-    mx.metal.clear_cache()
+    # Generate completions without tracking gradients
+    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
+        model=model,
+        tokenizer=tokenizer,
+        prompt_tokens=prompt_tokens,
+        max_tokens=max_tokens,
+        group_size=group_size,
+        temperature=temperature,
+        batch_size=batch_size
+    )

     # If we didn't generate any completions, return early
     if not all_completions:
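The bookkeeping above relies on completions coming back group-by-group: group_size completions for the first prompt of the sub-batch, then group_size for the next, and so on, so prompt_idx = i + (j // group_size) maps each completion back to its prompt. A small worked example of that arithmetic, with made-up sizes (the ordering assumption is mine, inferred from this indexing):

# Worked example of the completion-to-prompt index mapping (illustrative values)
group_size = 4
current_batch_size = 2   # prompts in this sub-batch
i = 6                    # index of the first prompt in this sub-batch

# j runs over the 8 completions returned for these 2 prompts
for j in range(current_batch_size * group_size):
    prompt_idx = i + (j // group_size)
    print(j, "->", prompt_idx)
# j = 0..3 map to prompt 6, j = 4..7 map to prompt 7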
@@ -415,25 +435,30 @@ def grpo_loss(
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices

     # Continue with the rest of the function
     # Create new input tensors for the model to compute logits with gradient tracking
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []

     for completion_ids in all_completions:
-        padding_length = max_length - completion_ids.shape[0]
+        # Convert the pre-generated completion to a regular tensor (not stop_gradient)
+        # This allows gradients to flow during the loss computation phase
+        completion_tensor = mx.array(completion_ids.tolist())
+
+        padding_length = max_length - completion_tensor.shape[0]
         if padding_length > 0:
-            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
-            padded_ids = mx.concatenate([completion_ids, padding])
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
             mask = mx.concatenate(
-                [mx.ones_like(completion_ids), mx.zeros_like(padding)]
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
             )
         else:
-            padded_ids = completion_ids
-            mask = mx.ones_like(completion_ids)
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)

     # Rest of the function remains the same
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
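The padding/mask pattern used here is easy to check in isolation: right-pad every completion to the longest one, build a 0/1 mask marking real tokens, and recover per-sequence lengths from the mask sum. A minimal sketch with made-up token ids (not part of the commit):

# Minimal sketch of the padding / attention-mask pattern above
import mlx.core as mx

completions = [mx.array([5, 9, 2]), mx.array([7, 1])]
max_length = max(c.shape[0] for c in completions)

padded, masks = [], []
for c in completions:
    pad_len = max_length - c.shape[0]
    if pad_len > 0:
        pad = mx.zeros((pad_len,), dtype=c.dtype)
        padded.append(mx.concatenate([c, pad]))
        masks.append(mx.concatenate([mx.ones_like(c), mx.zeros_like(pad)]))
    else:
        padded.append(c)
        masks.append(mx.ones_like(c))

inputs = mx.stack(padded)          # shape (2, 3)
attention_mask = mx.stack(masks)   # 1 where real tokens, 0 where padding
lengths = attention_mask.sum(axis=1)
print(lengths)                     # -> [3, 2]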
@@ -721,7 +746,6 @@ def evaluate_grpo(
             ref_model=ref_model,
             temperature=temperature,
             max_tokens=max_tokens,
             is_validation=True
         )

         all_losses += losses * toks
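The accumulation all_losses += losses * toks weights each batch's loss by its token count, so the final evaluation metric is a per-token average rather than a per-batch average. A small sketch of that averaging with made-up numbers (the all_tokens counter and the final division are assumptions based on the usual pattern; they are not shown in this hunk):

# Sketch of token-weighted loss averaging (illustrative values only)
all_losses = 0.0
all_tokens = 0

for losses, toks in [(2.0, 100), (1.5, 60)]:   # made-up per-batch loss and token count
    all_losses += losses * toks
    all_tokens += toks

avg_loss = all_losses / all_tokens   # per-token average over the whole eval set
print(avg_loss)                      # 1.8125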