From 6c58aa995cf6a0e432c158c8249ffa922a13b27d Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 31 Jan 2025 16:27:31 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_trainer.py | 61 ++++++++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 02b63e4c..9a8a57b7 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -44,7 +44,7 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def compute_rewards(sequences, batch_size, group_size):
+def compute_default_rewards(sequences, batch_size, group_size):
     """
     Args:
         sequences: List of word sequences
@@ -72,6 +72,7 @@ def grpo_loss(
     model,
     tokenizer,
     prompts,
+    reward_funcs=None,
     beta=0.1,
     group_size=4,
     epslion=1e-4,
@@ -134,7 +135,10 @@ def grpo_loss(
     kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
 
     # Calculate rewards
-    rewards = compute_rewards(all_completions, batch_size, group_size)
+    if reward_funcs:
+        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
+    else:
+        rewards = compute_default_rewards(all_completions, batch_size, group_size)
 
     # Compute grouped-wise rewards
     grouped_rewards = rewards.reshape(batch_size, group_size)
@@ -266,6 +270,59 @@ def evaluate_grpo(
     return (all_losses / ntokens).item()
 
 
+def evaluate_grpo(
+    model,
+    ref_model,
+    dataset,
+    tokenizer,
+    batch_size,
+    num_batches,
+    beta: float,
+    epslion: float,
+    group_size: int,
+    max_seq_length,
+    reward_funcs=None,
+    loss: callable = grpo_loss,
+    iterate_batches: callable = iterate_batches
+):
+    all_losses = 0
+    ntokens = 0
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+    for _, batch in zip(
+        index_iterator,
+        iterate_batches(
+            dataset=dataset,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_seq_length=max_seq_length,
+        ),
+    ):
+        # Extract prompts from the batch (assuming the batch contains 'prompts')
+        prompts = batch.get("prompts", None)
+
+        # Call the loss function with the correct arguments
+        losses, toks, metrics = loss(
+            model=model,
+            tokenizer=tokenizer,
+            prompts=prompts,
+            reward_funcs=reward_funcs,
+            beta=beta,
+            group_size=group_size,
+            epslion=epslion,
+            ref_model=ref_model
+        )
+
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
+
+    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+
+    return (all_losses / ntokens).item()
+
+
 def train(
     model,
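
Note (not part of the patch): the new reward_funcs hook lets callers replace compute_default_rewards with their own scoring functions. Inside grpo_loss, the outputs of every function in reward_funcs are summed over the sampled completions before the group-wise reshape, so each function should return one score per completion, i.e. an array of length batch_size * group_size. The sketch below shows one way a caller might use the hook under that assumption; length_reward and keyword_reward are hypothetical helpers and the model paths are placeholders, none of which are defined by this patch.

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.tuner.grpo_trainer import grpo_loss

    def length_reward(completions):
        # Toy shaping signal: reward longer completions (hypothetical helper).
        return mx.array([float(len(c.split())) for c in completions])

    def keyword_reward(completions):
        # Toy correctness proxy: 1.0 if "answer" appears in the text (hypothetical helper).
        return mx.array([1.0 if "answer" in c.lower() else 0.0 for c in completions])

    model, tokenizer = load("path/to/policy-model")    # placeholder path
    ref_model, _ = load("path/to/reference-model")     # placeholder path

    # grpo_loss returns (loss, token count, metrics), matching how the
    # patched evaluate_grpo unpacks the loss callable above.
    loss, ntoks, metrics = grpo_loss(
        model=model,
        tokenizer=tokenizer,
        prompts=["Explain GRPO in one sentence."],
        reward_funcs=[length_reward, keyword_reward],  # summed elementwise inside grpo_loss
        beta=0.1,
        group_size=4,
        epslion=1e-4,  # spelling matches the parameter name used throughout the patch
        ref_model=ref_model,
    )

Since sum() starts from 0 and adds each function's result in turn, per-completion mx.array scores combine elementwise; a reward function returning a plain Python list would raise a TypeError at that point.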