From ca32424043e73eda80ba57ecd8c53c909fdb8ebd Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 3 Feb 2025 21:57:26 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_trainer.py | 68 +++++++++++++++++++++++++------
 1 file changed, 55 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index f4b0b9d6..29518d8f 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -164,6 +164,42 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores
 
 
+def get_per_token_logps(model, inputs, lengths):
+    # Get logits from model
+    logits = model(inputs).astype(mx.float32)  # [batch_size, seq_len, vocab_size]
+    # Remove last position as it corresponds to the next token prediction
+    logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
+    targets = inputs[:, 1:]  # Shift inputs to get targets
+
+    # Process sequences individually to save memory
+    per_token_logps = []
+    for i in range(logits.shape[0]):
+        # Get sequence length for this example
+        seq_len = int(lengths[i]) - 1  # -1 because we removed last position
+
+        # Get logits and targets for this sequence
+        seq_logits = logits[i, :seq_len]  # [seq_len, vocab_size]
+        seq_targets = targets[i, :seq_len]  # [seq_len]
+
+        # Compute log probabilities
+        log_probs = nn.log_softmax(seq_logits, axis=-1)  # [seq_len, vocab_size]
+
+        # Gather log probs for actual tokens
+        token_log_probs = mx.take_along_axis(
+            log_probs,
+            seq_targets.reshape(seq_len, 1),
+            axis=-1
+        ).squeeze(-1)  # [seq_len]
+
+        per_token_logps.append(token_log_probs)
+
+        # Clean up intermediates
+        del seq_logits, seq_targets, log_probs, token_log_probs
+        mx.metal.clear_cache()
+
+    return per_token_logps
+
+
 def grpo_loss(
     model,
     tokenizer,
@@ -248,24 +284,30 @@ def grpo_loss(
     targets = inputs[:, 1:]
 
     # Current policy probabilities
-    token_log_probs = mx.take_along_axis(
-        log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
+    token_log_probs = get_per_token_logps(model, inputs, lengths)
 
     # Reference policy probabilities
     if ref_model is not None:
-        ref_logits = ref_model(inputs).astype(mx.float32)
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
     else:
-        ref_logits = mx.array(logits)
+        ref_token_log_probs = token_log_probs
 
-    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
-    ref_token_log_probs = mx.take_along_axis(
-        ref_log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
+    max_len = max(x.shape[0] for x in token_log_probs)
+    padded_log_probs = []
+    padded_ref_log_probs = []
+
+    for i in range(len(token_log_probs)):
+        seq_len = token_log_probs[i].shape[0]
+        padding = mx.zeros((max_len - seq_len,), dtype=mx.float32)
+
+        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+
+        del padding
+        mx.metal.clear_cache()
+
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
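Reviewer note (not part of the patch): below is a minimal, standalone sketch of the per-token log-prob gathering that the new get_per_token_logps helper performs, so the shifting and indexing can be verified in isolation. The batch size, vocabulary size, random logits, and toy lengths array are made up for illustration; only mlx is assumed, and the model call is replaced by random logits of the same shape.

import mlx.core as mx
import mlx.nn as nn

batch_size, seq_len, vocab_size = 2, 5, 11
mx.random.seed(0)

# Stand-in for model(inputs): [batch_size, seq_len, vocab_size]
logits = mx.random.normal((batch_size, seq_len, vocab_size))
inputs = mx.random.randint(0, vocab_size, (batch_size, seq_len))
lengths = mx.array([5, 3])  # unpadded length of each sequence

# Drop the last position (it predicts a token past the sequence end)
# and shift inputs so targets line up with the remaining predictions.
shifted_logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
targets = inputs[:, 1:]             # [batch_size, seq_len-1]

per_token_logps = []
for i in range(batch_size):
    n = int(lengths[i]) - 1  # -1 because the last position was dropped
    log_probs = nn.log_softmax(shifted_logits[i, :n], axis=-1)
    token_log_probs = mx.take_along_axis(
        log_probs, targets[i, :n].reshape(n, 1), axis=-1
    ).squeeze(-1)  # [n] log-probs of the tokens actually observed
    per_token_logps.append(token_log_probs)

for i, lp in enumerate(per_token_logps):
    print(f"sequence {i}: {lp.shape[0]} token log-probs")  # 4 and 2

As in the patch, the per-sequence results have different lengths, which is why grpo_loss zero-pads them to max_len before stacking.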