diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 954eb81c..9d938df8 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -13,6 +13,7 @@ import numpy as np
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
 from ..utils import generate_step
+from ..models import cache
 
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -96,15 +97,13 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
         if reason_content and answer_content:
             scores.append(0.5)
             continue
-        scores.append(0.0)
-
     return scores
 
 
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
+    pattern = r"<think>\n.*?\n</think>\n<answer>.*?</answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
@@ -120,28 +119,31 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
         count = 0.0
         if text.count("<think>\n") == 1:
             count += 0.125
-        if text.count("\n</think>\n") == 1:
+        if text.count("</think>") == 1:
             count += 0.125
-        if text.count("\n<answer>\n") == 1:
+        if text.count("<answer>") == 1:
             count += 0.125
-        if text.count("\n</answer>\n") == 1:
+        if text.count("</answer>") == 1:
             count += 0.125
-        end_text = text.split("\n</answer>\n")[-1]
+        end_text = text.split("</answer>")[-1]
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>"):
+    if not model.training:
+        print("Model is in eval mode; manually setting it to train mode for generation")
+        model.train()
+
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode("</answer>"))
+    end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
     tokens_generated = 0
 
     for idx in range(batch_size):
@@ -153,13 +155,15 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
         if is_training:
             current_input = expanded_prompts[idx]
+            prompt_cache = cache.make_prompt_cache(model)
+
+            # Initial forward pass with the prompt
+            logits = model(current_input[None], cache=prompt_cache)[:, -1]
+
             while len(current_tokens) < max_tokens:
-                logits = model(current_input[None])[:, -1]
                 probs = nn.softmax(logits, axis=-1)
                 next_token = mx.argmax(probs, axis=-1)
                 token = next_token.item()
-                current_tokens.append(token)
-                tokens_generated += 1
                 if token == tokenizer.eos_token_id:
                     break
                 if (len(current_tokens) >= len(end_sequence) and
                     mx.array_equal(
@@ -168,10 +172,13 @@
                     )):
                     break
-                current_input = mx.concatenate([current_input, mx.array([token])])
-                if len(current_tokens) % 32 == 0:
-                    mx.eval(current_input)
-                    mx.metal.clear_cache()
+
+                current_tokens.append(token)
+                tokens_generated += 1
+                current_input = mx.array([token])
+                logits = model(current_input[None], cache=prompt_cache)[:, -1]
+                mx.eval(current_input)
+                mx.metal.clear_cache()
         else:
             generator = generate_step(
                 expanded_prompts[idx],
                 model,
@@ -180,10 +187,10 @@
                 max_tokens=max_tokens,
                 sampler=lambda x: mx.argmax(x, axis=-1)
             )
             for token, _ in generator:
-                current_tokens.append(token)
-                tokens_generated += 1
                 if token == tokenizer.eos_token_id:
                     break
+                current_tokens.append(token)
+                tokens_generated += 1
 
         if current_tokens:
             results.append(mx.array(current_tokens))
@@ -401,7 +408,7 @@ def grpo_loss(
 
         if is_validation:
             print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
-
+            print(f"Validation sample answer:\n{answer_text[-1]}\n")
         mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics
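Note on the `r1_strict_format_reward_func` change: dropping the `^...$` anchors and the newline requirements around the `<answer>` block means the 0.5 reward now fires when the tag structure appears anywhere in the completion, not only when the completion is exactly that shape. A quick sanity check of the two patterns, using a hypothetical completion string (not from the patch), illustrates the difference:

```python
import re

OLD_PATTERN = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
NEW_PATTERN = r"<think>\n.*?\n</think>\n<answer>.*?</answer>"

# Hypothetical completion with extra text around the tags.
completion = "Sure!<think>\nstep by step\n</think>\n<answer>42</answer> Done."

print(bool(re.search(OLD_PATTERN, completion)))  # False: anchors require the exact shape
print(bool(re.search(NEW_PATTERN, completion)))  # True: tags may appear mid-string
```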
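The main change in `generate_grpo` is reusing a KV cache across decode steps (`cache.make_prompt_cache`) instead of re-running the model over the whole growing prefix on every iteration. A minimal standalone sketch of that pattern, assuming `model` and `tokenizer` come from `mlx_lm.utils.load`; the helper name `greedy_decode` is illustrative, not part of the patch:

```python
import mlx.core as mx
from mlx_lm.models import cache

def greedy_decode(model, tokenizer, prompt_ids, max_tokens=128):
    prompt_cache = cache.make_prompt_cache(model)
    # Prefill: one pass over the full prompt populates the per-layer KV cache.
    logits = model(prompt_ids[None], cache=prompt_cache)[:, -1]
    tokens = []
    for _ in range(max_tokens):
        token = mx.argmax(logits, axis=-1).item()
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token)
        # Decode: feed only the new token; keys/values for past positions are
        # served from the cache, so each step avoids re-encoding the prefix.
        logits = model(mx.array([token])[None], cache=prompt_cache)[:, -1]
    return tokens
```

The patch keeps the `nn.softmax` call before the argmax; it is harmless but unnecessary for greedy selection, since softmax preserves the argmax.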