Goekdeniz-Guelmez 2025-03-09 00:18:33 +01:00
parent 0bc2a881ad
commit e88f0fad4b


@@ -121,9 +121,9 @@ def generate_grpo(
     prompt_tokens,
     max_tokens: int,
     group_size: int,
-    end_token: str = "</answer>",
-    temperature: float = 0.8,
-    batch_size: int = 1,
+    temperature: float,
+    batch_size: int,
+    end_token: str = "</answer>"
 ):
     try:
         end_sequence = mx.array(tokenizer.encode(end_token))
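The reorder is forced by Python's grammar: once temperature and batch_size lose their defaults, they may no longer follow the defaulted end_token, so end_token moves to the end of the signature. Callers must now supply both values explicitly. A minimal sketch of an updated call site, assuming it runs inside train_grpo with an args namespace (the max_completion_length field name is an assumption, not taken from this diff):

# Hypothetical call site after this change; values are illustrative.
all_completions, all_completion_texts, batch_indices = generate_grpo(
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=prompt_tokens,
    max_tokens=args.max_completion_length,  # assumed config field
    group_size=args.group_size,
    temperature=args.temperature,  # required now: default 0.8 was removed
    batch_size=args.batch_size,    # required now: default 1 was removed
)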
@@ -239,7 +239,6 @@ def grpo_loss(
 
     expanded_answers = []
     expanded_prompts = []
-
     unique_prompt_indices = sorted(set(batch_indices))
     grouped_completions = {idx: [] for idx in unique_prompt_indices}
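The surviving lines show grpo_loss regrouping the flat completion list by the prompt that produced each completion, so that every prompt's group can be scored together. A self-contained sketch of that regrouping under GRPO's group sampling (the list values are illustrative; only unique_prompt_indices and grouped_completions appear in the diff):

# Each completion remembers which prompt produced it; with group_size=2,
# prompt 0 and prompt 1 each contribute two completions.
batch_indices = [0, 1, 0, 1]
all_completions = ["c0-a", "c1-a", "c0-b", "c1-b"]

unique_prompt_indices = sorted(set(batch_indices))
grouped_completions = {idx: [] for idx in unique_prompt_indices}
for idx, completion in zip(batch_indices, all_completions):
    grouped_completions[idx].append(completion)

# grouped_completions == {0: ["c0-a", "c0-b"], 1: ["c1-a", "c1-b"]}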
@@ -262,7 +261,6 @@ def grpo_loss(
     all_completions = ordered_completions
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
-
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
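max_length is then used to right-pad every completion to a common length and to build matching attention masks, so the batch can be stacked into a single array. A sketch of that padding step, assuming the token ids are 1-D mx.array values and a pad id of 0 (the real code may use the tokenizer's pad token instead):

import mlx.core as mx

max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
attention_masks = []
for ids in all_completions:
    pad_len = max_length - ids.shape[0]
    # Right-pad with 0s (assumed pad id) and mask the padding out.
    padded_completions.append(
        mx.concatenate([ids, mx.zeros(pad_len, dtype=ids.dtype)])
    )
    attention_masks.append(
        mx.concatenate([mx.ones(ids.shape[0]), mx.zeros(pad_len)])
    )
inputs = mx.stack(padded_completions)
mask = mx.stack(attention_masks)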
@@ -617,11 +615,8 @@ def train_grpo(
     state = [model.state, optimizer.state]
 
     def step(batch):
-        # Extract prompt tokens from the batch
         prompt_tokens, targets, prompt_lens, target_lens = batch
 
-        # First, generate completions without gradient tracking
-        # The model will be frozen during this call
         all_completions, all_completion_texts, batch_indices = generate_grpo(
             model=model,
             tokenizer=tokenizer,
@@ -630,9 +625,7 @@ def train_grpo(
             group_size=args.group_size,
             temperature=args.temperature
         )
 
-        # Now calculate loss and gradients with pre-generated completions
-        # We need to update loss_fn to accept these pre-generated completions
         (loss, toks, metrics), grad = loss_value_and_grad(
             model,
             tokenizer=tokenizer,
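Even with the comments deleted, step keeps its two-phase shape: generation runs first, outside the differentiated function, and the sampled completions are then fed to the loss, so nn.value_and_grad only differentiates the scoring pass, not sampling. A compressed sketch of the whole step under that reading (the keyword names passed to the loss, and the args field names, are assumptions beyond what this diff shows):

import mlx.nn as nn

# Assumed construction: differentiate grpo_loss w.r.t. the model parameters.
loss_value_and_grad = nn.value_and_grad(model, grpo_loss)

def step(batch):
    prompt_tokens, targets, prompt_lens, target_lens = batch

    # Phase 1: sample completions; no gradients flow through generation.
    all_completions, all_completion_texts, batch_indices = generate_grpo(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,
        max_tokens=args.max_completion_length,  # assumed field name
        group_size=args.group_size,
        temperature=args.temperature,
        batch_size=args.batch_size,             # assumed field name
    )

    # Phase 2: score the fixed completions and differentiate the GRPO loss.
    (loss, toks, metrics), grad = loss_value_and_grad(
        model,
        tokenizer=tokenizer,
        completions=all_completions,            # assumed keyword
        completion_texts=all_completion_texts,  # assumed keyword
        batch_indices=batch_indices,            # assumed keyword
    )
    optimizer.update(model, grad)
    return loss, toks, metrics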