From d9c4c6e60cba40feff013676fef1cf04971b9543 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sat, 22 Feb 2025 02:34:56 +0100
Subject: [PATCH] clean up and readding temperature argument

---
 llms/mlx_lm/tuner/grpo_trainer.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 8b20e71f..9d2dfd42 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -131,7 +131,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = ""):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "", temperature: float = 0.8):
     if model.training == False:
         print("Model is in training mode", model.training, "Manually setting to eval mode")
         model.train()
@@ -146,8 +146,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     end_sequence = mx.array(tokenizer.encode(end_token))
 
     results = []
-    tokens_generated = 0
-    start_time = time.perf_counter()
 
     try:
         for idx in range(batch_size):
@@ -160,20 +158,23 @@
             logits = model(current_input[None], cache=prompt_cache)[:, -1]
 
             while len(current_tokens) < max_tokens:
-                probs = nn.softmax(logits, axis=-1)
+                logits_temp = logits / temperature
+                probs = nn.softmax(logits_temp, axis=-1)
                 next_token = mx.argmax(probs, axis=-1)
                 token = next_token.item()
-                if token == tokenizer.eos_token_id:
-                    break
-                if (len(current_tokens) >= len(end_sequence) and
+
+                test_sequence = current_tokens + [token]
+                if (len(test_sequence) >= len(end_sequence) and
                     mx.array_equal(
-                        mx.array(current_tokens[-len(end_sequence):]),
+                        mx.array(test_sequence[-len(end_sequence):]),
                         end_sequence
                     )):
                     break
 
+                if token == tokenizer.eos_token_id:
+                    break
+
                 current_tokens.append(token)
-                tokens_generated += 1
                 current_input = mx.array([token])
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
                 mx.eval(current_input)
@@ -189,15 +190,12 @@
                 if token == tokenizer.eos_token_id:
                     break
                 current_tokens.append(token)
-                tokens_generated += 1
 
             if current_tokens:
                 results.append(mx.array(current_tokens))
             mx.metal.clear_cache()
 
         mx.eval(results)
-        generation_time = time.perf_counter() - start_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
         return results
 
     except Exception as e:
@@ -266,7 +264,8 @@ def grpo_loss(
             max_tokens,
             tokenizer,
             group_size,
-            is_training=True
+            is_training=True,
+            temperature=temperature
         )
         if completions is not None:
             for completion_ids in completions:
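
Note (illustrative sketch, not part of the patch): the change above adds a temperature argument to generate_grpo and divides the logits by it before the softmax. A minimal standalone example of that scaling, assuming only that mlx is installed; the logits values below are made up:

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([2.0, 1.0, 0.5])  # toy logits for a three-token vocabulary
    temperature = 0.8                   # default value introduced by the patch

    # Dividing by a temperature below 1 sharpens the distribution, above 1
    # flattens it; the patch applies this scaling right before the softmax.
    probs = nn.softmax(logits / temperature, axis=-1)
    print(probs)

Inside generate_grpo the resulting probabilities are then used to pick the next token, as shown in the @@ -160 hunk above.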