From b7bc8115078dd093bca583c9291d9d23fec1c752 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 10 Feb 2025 19:45:19 +0100
Subject: [PATCH] nits

---
 llms/mlx_lm/tuner/grpo_trainer.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 36b44ac2..ab1ab605 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -159,24 +159,19 @@ def get_per_token_logps(model, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
     logits = logits[:, :-1, :]
     targets = inputs[:, 1:]
-
+    mx.eval(logits)
     per_token_logps = []
     for i in range(logits.shape[0]):
         seq_len = int(lengths[i]) - 1
-
         seq_logits = logits[i, :seq_len]
         seq_targets = targets[i, :seq_len]
-
         log_probs = nn.log_softmax(seq_logits, axis=-1)
-
         token_log_probs = mx.take_along_axis(
             log_probs, seq_targets.reshape(seq_len, 1), axis=-1
         ).squeeze(-1)
-
         per_token_logps.append(token_log_probs)
-    mx.eval(logits)
     return per_token_logps
@@ -270,8 +265,8 @@ def grpo_loss(
         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
 
-    token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
-    ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
@@ -299,7 +294,7 @@ def grpo_loss(
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
     # Compute policy ratio
-    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))
+    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
 
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
@@ -580,7 +575,7 @@ def train_grpo(
                 for i, reward_func in enumerate(reward_funcs):
                     val_metrics_str += (
                         f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
-                        # f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
+                        f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                     )
 
                 print(
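
Note: the first hunk tightens get_per_token_logps, which gathers the
log-probability the model assigned to each realized next token. A minimal,
self-contained sketch of that gather on toy shapes (batch of 1, vocabulary
of 5; all values hypothetical, only mlx is assumed to be installed):

    import mlx.core as mx
    import mlx.nn as nn

    vocab_size = 5
    inputs = mx.array([[3, 1, 4, 2]])  # (batch=1, seq=4) token ids
    logits = mx.random.normal((1, 4, vocab_size)).astype(mx.float16)

    # Shift: logits at position t score the token at position t + 1.
    logits, targets = logits[:, :-1, :], inputs[:, 1:]

    log_probs = nn.log_softmax(logits[0], axis=-1)  # (3, vocab)
    token_log_probs = mx.take_along_axis(
        log_probs, targets[0].reshape(-1, 1), axis=-1
    ).squeeze(-1)  # (3,)
    print(token_log_probs)  # log-prob of each realized next token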
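
Note: the second and third hunks drop the .astype(mx.float32) casts, so the
stacked log-probs keep their float16 dtype through the ratio computation. A
hedged sketch of the loss core those hunks touch, with toy values; the
k3-style kl_div term and the beta/advantages/lengths values are assumptions,
since they come from parts of grpo_loss not shown in this patch:

    import mlx.core as mx

    token_log_probs = mx.random.normal((2, 3)).astype(mx.float16)
    ref_token_log_probs = token_log_probs + 0.01 * mx.random.normal((2, 3)).astype(mx.float16)

    beta = 0.1                          # KL penalty weight (assumed)
    advantages = mx.array([0.5, -0.5])  # one advantage per completion (toy)
    lengths = mx.array([3, 2])
    length_mask = mx.arange(3)[None, :] < lengths[:, None]  # mask out padding

    # k3 KL estimate, a common choice in GRPO-style trainers (an assumption;
    # the patch does not show how kl_div is computed).
    ratio_log = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(ratio_log) - ratio_log - 1

    # Policy ratio and per-token loss as in the patched hunks.
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
    print((per_token_loss.sum() / length_mask.sum()).item())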