Goekdeniz-Guelmez 2025-01-31 16:27:31 +01:00
parent 93370ff1c3
commit 6c58aa995c

@@ -44,7 +44,7 @@ class GRPOTrainingArgs(TrainingArgs):
)
def compute_rewards(sequences, batch_size, group_size):
def compute_default_rewards(sequences, batch_size, group_size):
"""
Args:
sequences: List of word sequences
@@ -72,6 +72,7 @@ def grpo_loss(
model,
tokenizer,
prompts,
reward_funcs=None,
beta=0.1,
group_size=4,
epsilon=1e-4,
@@ -134,7 +135,10 @@ def grpo_loss(
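# Per-token KL(policy || ref) estimated with the k3 form exp(r) - r - 1,
# where r = ref_token_log_probs - token_log_probs; this estimate stays non-negative.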
kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
# Calculate rewards
rewards = compute_rewards(all_completions, batch_size, group_size)
if reward_funcs:
    # Assumes each reward function maps the list of completions to one scalar
    # reward per completion; rewards from all functions are summed element-wise.
    rewards = mx.sum(
        mx.stack([mx.array(rf(all_completions)) for rf in reward_funcs]), axis=0
    )
else:
    rewards = compute_default_rewards(all_completions, batch_size, group_size)
# Compute grouped-wise rewards
grouped_rewards = rewards.reshape(batch_size, group_size)
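Note: a minimal sketch of how custom reward functions could plug into the new reward_funcs hook, assuming each callable takes the list of decoded completions and returns one float per completion; the function names below are illustrative, not part of this commit:

def length_penalty_reward(completions):
    # Hypothetical reward: shorter completions score higher.
    return [-0.01 * len(c) for c in completions]

def contains_answer_reward(completions):
    # Hypothetical reward: +1.0 when a completion mentions "answer".
    return [1.0 if "answer" in c else 0.0 for c in completions]

# Passed through grpo_loss(..., reward_funcs=[length_penalty_reward, contains_answer_reward]);
# with reward_funcs=None the loss falls back to compute_default_rewards.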
@@ -266,6 +270,59 @@ def evaluate_grpo(
return (all_losses / ntokens).item()
def evaluate_grpo(
model,
ref_model,
dataset,
tokenizer,
batch_size,
num_batches,
beta: float,
epsilon: float,
group_size: int,
max_seq_length,
reward_funcs=None,
loss: callable = grpo_loss,
iterate_batches: callable = iterate_batches
):
all_losses = 0
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
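# iter(int, 1) is an infinite iterator: int() always returns 0, which never
# equals the sentinel 1, so num_batches == -1 walks the entire dataset.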
for _, batch in zip(
index_iterator,
iterate_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
# Extract prompts from the batch (assuming the batch contains 'prompts')
prompts = batch.get("prompts", None)
# Call the loss function with the correct arguments
losses, toks, metrics = loss(
model=model,
tokenizer=tokenizer,
prompts=prompts,
reward_funcs=reward_funcs,
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model
)
all_losses += losses * toks
ntokens += toks
mx.eval(all_losses, ntokens)
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
return (all_losses / ntokens).item()
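A hedged usage sketch of the extended evaluate_grpo; the dataset, tokenizer, and reward callables below are placeholders, not names from this commit:

val_loss = evaluate_grpo(
    model=model,
    ref_model=ref_model,
    dataset=val_set,        # placeholder validation split
    tokenizer=tok,          # placeholder tokenizer
    batch_size=4,
    num_batches=-1,         # -1 iterates over the full dataset
    beta=0.1,
    epsilon=1e-4,
    group_size=4,
    max_seq_length=512,
    reward_funcs=[length_penalty_reward, contains_answer_reward],
)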
def train(
model,