From 80bcf6895698155b3ebd56c5d48cc88bf438e674 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 31 Jan 2025 16:54:18 +0100
Subject: [PATCH] grpo_trainer should be done

---
 llms/mlx_lm/tuner/grpo_trainer.py | 117 +++++++++++++-----------------
 1 file changed, 50 insertions(+), 67 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 9a8a57b7..b5030ce0 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -240,88 +240,55 @@ def evaluate_grpo(
     epslion: float,
     group_size: int,
     max_seq_length,
+    reward_funcs=None,
     loss: callable = grpo_loss,
     iterate_batches: callable = iterate_batches
 ):
     all_losses = 0
     ntokens = 0
+    all_metrics = None
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
 
     for _, batch in zip(
         index_iterator,
         iterate_batches(
             dataset=dataset,
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
     ):
-        losses, toks = loss(
-            model,
-            *batch
-        )
-        all_losses += losses * toks
-        ntokens += toks
-        mx.eval(all_losses, ntokens)
-
-    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
-    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
-
-    return (all_losses / ntokens).item()
-
-def evaluate_grpo(
-    model,
-    ref_model,
-    dataset,
-    tokenizer,
-    batch_size,
-    num_batches,
-    beta: float,
-    epslion: float,
-    group_size: int,
-    max_seq_length,
-    reward_funcs=None,
-    loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches
-):
-    all_losses = 0
-    ntokens = 0
-
-    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-
-    for _, batch in zip(
-        index_iterator,
-        iterate_batches(
-            dataset=dataset,
-            tokenizer=tokenizer,
-            batch_size=batch_size,
-            max_seq_length=max_seq_length,
-        ),
-    ):
-        # Extract prompts from the batch (assuming the batch contains 'prompts')
-        prompts = batch.get("prompts", None)
-
-        # Call the loss function with the correct arguments
+        prompts = batch
         losses, toks, metrics = loss(
             model=model,
             tokenizer=tokenizer,
             prompts=prompts,
             reward_funcs=reward_funcs,
             beta=beta,
             group_size=group_size,
             epslion=epslion,
             ref_model=ref_model
         )
-
         all_losses += losses * toks
         ntokens += toks
+
+        if all_metrics is None:
+            all_metrics = {k: v * toks for k, v in metrics.items()}
+        else:
+            for k, v in metrics.items():
+                all_metrics[k] += v * toks
+
         mx.eval(all_losses, ntokens)
 
     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
 
-    return (all_losses / ntokens).item()
+    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
+    avg_loss = (all_losses / ntokens).item()
+
+    return avg_loss, ntokens, avg_metrics
 
 
 def train(
@@ -335,7 +302,7 @@ def train(
     iterate_batches: callable = iterate_batches,
     training_callback: TrainingCallback = None,
 ):
-    print(f"Starting training..., iters: {args.iters}")
+    print(f"Starting GRPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
     rank = world.rank()
@@ -349,26 +316,30 @@ def train(
     def step(batch):
 
         # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+        (lvalue, toks, metrics), grad = loss_value_and_grad(model, *batch)
 
         # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)
 
         # Model update
         optimizer.update(model, grad)
 
-        return lvalue, toks
+        return lvalue, toks, metrics
 
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    # Save initial model weights as reference
-    ref_weights = {k: v.copy() for k, v in model.parameters().items()}
-
     losses = 0
     n_tokens = 0
     steps = 0
     trained_tokens = 0
-    # Main training loop
+    accumulated_metrics = {
+        'rewards': 0,
+        'rewards_std': 0,
+        'grouped_rewards': 0,
+        'grouped_rewards_std': 0,
+        'kl': 0
+    }
+
     start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
@@ -384,7 +355,7 @@ def train(
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
-            val_loss = evaluate(
+            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                 model=model,
                 dataset=val_dataset,
                 loss=loss,
@@ -398,61 +369,73 @@ def train(
             if rank == 0:
                 print(
                     f"Iter {it}: "
-                    f"Val loss {val_loss:.3f}, "
+                    f"Val loss {val_loss:.8f}, "
+                    f"Val rewards {val_metrics['rewards']:.3f}, "
+                    f"Val rewards_std {val_metrics['rewards_std']:.3f}, "
+                    f"Val grouped_rewards {val_metrics['grouped_rewards']:.3f}, "
+                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
+                    f"Val kl {val_metrics['kl']:.3f}, "
                     f"Val took {val_time:.3f}s",
                     flush=True,
                 )
 
             if training_callback is not None:
-                val_info = {
+                training_callback.on_val_loss_report({
                     "iteration": it,
                     "val_loss": val_loss,
+                    **{f"val_{k}": v for k, v in val_metrics.items()},
                     "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
+                })
 
             start = time.perf_counter()
 
-        lvalue, toks = step(batch)
+        lvalue, toks, metrics = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
+        for k, v in metrics.items():
+            accumulated_metrics[k] += v
         mx.eval(state, losses, n_tokens)
 
-        # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
+            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
+
             if rank == 0:
                 print(
-                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Iter {it}: Train loss {train_loss:.8f}, "
+                    f"Rewards {avg_metrics['rewards']:.3f}, "
+                    f"Rewards_std {avg_metrics['rewards_std']:.3f}, "
+                    f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
+                    f"Grouped Rewards_std {avg_metrics['grouped_rewards_std']:.3f}, "
+                    f"KL {avg_metrics['kl']:.3f}, "
                     f"Learning Rate {learning_rate:.3e}, "
                     f"It/sec {it_sec:.3f}, "
                     f"Tokens/sec {tokens_sec:.3f}, "
-                    f"Trained Tokens {trained_tokens}, "
                     f"Peak mem {peak_mem:.3f} GB",
                     flush=True,
                 )
 
             if training_callback is not None:
-                train_info = {
+                training_callback.on_train_loss_report({
                     "iteration": it,
                     "train_loss": train_loss,
+                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
                     "tokens_per_second": tokens_sec,
                     "trained_tokens": trained_tokens,
                     "peak_memory": peak_mem,
-                }
-                training_callback.on_train_loss_report(train_info)
+                })
 
             losses = 0
             n_tokens = 0
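
Usage sketch: with this patch, evaluate_grpo returns a (loss, ntokens, metrics) triple instead of a bare scalar. The snippet below is a minimal illustration of how a caller might consume it. It assumes model, ref_model, tokenizer, and val_dataset have already been loaded elsewhere, and length_reward is a hypothetical reward function; the exact reward-function signature and the hyperparameter values are assumptions for illustration, not something this patch defines.

    from mlx_lm.tuner.grpo_trainer import evaluate_grpo

    def length_reward(completions, **kwargs):
        # Hypothetical reward: favor shorter completions.
        return [1.0 / (1.0 + len(c)) for c in completions]

    val_loss, val_ntokens, val_metrics = evaluate_grpo(
        model=model,                   # policy model (assumed already loaded)
        ref_model=ref_model,           # frozen reference model for the KL term
        dataset=val_dataset,
        tokenizer=tokenizer,
        batch_size=4,
        num_batches=-1,                # -1 sweeps the whole dataset
        beta=0.1,                      # KL penalty coefficient
        epslion=1e-4,                  # parameter name as spelled in this file
        group_size=4,                  # completions sampled per prompt
        max_seq_length=512,
        reward_funcs=[length_reward],  # assumed to accept a list of callables
    )
    print(f"Val loss {val_loss:.8f}, Val kl {val_metrics['kl']:.3f}")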