diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index c29b2f5d..b9d58c01 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -183,15 +183,24 @@ def get_per_token_logps(model, inputs, lengths):
     return per_token_logps
 
 
+def compute_kl(logprobs1, logprobs2):
+    ratio = mx.exp(logprobs1 - logprobs2)
+    return ratio - 1 - (logprobs1 - logprobs2)
+
+
+def compute_policy_ratio(current_logprobs, ref_logprobs):
+    return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
+
+
 def grpo_loss(
     model,
+    ref_model,
     tokenizer,
     batch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
     epsilon=1e-4,
-    ref_model=None,
     max_tokens=64,
     temperature=1.0
 ):
@@ -257,10 +266,10 @@ def grpo_loss(
     mx.metal.clear_cache()
 
     # Reference policy probabilities
-    if ref_model is not None:
-        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
-    else:
+    if ref_model is None:
         ref_token_log_probs = token_log_probs
+    else:
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
 
     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
@@ -268,13 +277,13 @@ def grpo_loss(
 
     for i in range(len(token_log_probs)):
         seq_len = token_log_probs[i].shape[0]
-        padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
+        padding = mx.zeros((max_len - seq_len,))
 
         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
 
-    token_log_probs = mx.stack(padded_log_probs)
-    ref_token_log_probs = mx.stack(padded_ref_log_probs)
+    token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
 
     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
@@ -296,26 +305,23 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
 
     # Compute KL divergence using Schulman's approximator
-    kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
+    kl_div = compute_kl(token_log_probs, ref_token_log_probs)
 
     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
     # Compute policy ratio
-    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
+    policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)
 
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
 
-    # Average over tokens and sequences
-    sequence_sums = per_token_loss.sum(axis=1)
-    sequence_lengths = length_mask.sum(axis=1)
-
-    loss = (sequence_sums / sequence_lengths).mean()
+    # Average over tokens
+    loss = per_token_loss.sum().mean()
 
     # Calculate mean KL divergence for metrics
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
-    
+
     # Collect reward metrics
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
@@ -326,7 +332,6 @@ def grpo_loss(
             answer=expanded_answers
         ))
         reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
-        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
 
     metrics = {
         'total_rewards_mean': mx.mean(rewards),
@@ -338,11 +343,10 @@
     }
 
     mx.metal.clear_cache()
-    return loss, sequence_lengths.sum(), metrics
+    return loss, length_mask.sum(axis=1).sum(), metrics
 
 
 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """Memory-optimized version of iterate_grpo_batches"""
     if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
         raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
@@ -585,7 +589,7 @@ def train_grpo(
             for i, reward_func in enumerate(reward_funcs):
                 val_metrics_str += (
                     f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
-                    f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
+                    # f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                )
            print(