From c51b0a2715cf75f16a514bfb01d473a4676f70c7 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sat, 22 Feb 2025 00:21:47 +0100
Subject: [PATCH] fix

---
 llms/mlx_lm/tuner/grpo_trainer.py | 130 ++++++++++++++++--------------
 1 file changed, 71 insertions(+), 59 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 3e581d13..12553b8a 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -76,10 +76,30 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    has_think = r"<reasoning>.*</reasoning>"
-    has_answer = r"<answer>.*</answer>"
-    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
-    return [0.5 if match else 0.0 for match in matches]
+
+    scores = []
+    for completion in completions:
+        if not completion:
+            scores.append(0.0)
+            continue
+
+        reason_start = completion.find("<reasoning>")
+        reason_end = completion.find("</reasoning>")
+        answer_start = completion.find("<answer>")
+        answer_end = completion.find("</answer>")
+
+        if (reason_start != -1 and reason_end != -1 and
+            answer_start != -1 and answer_end != -1 and
+            reason_start < reason_end < answer_start < answer_end):
+            reason_content = completion[reason_start + len("<reasoning>"):reason_end].strip()
+            answer_content = completion[answer_start + len("<answer>"):answer_end].strip()
+            if reason_content and answer_content:
+                scores.append(0.5)
+                continue
+
+        scores.append(0.0)
+
+    return scores
 
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
@@ -110,7 +130,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores
-    
+

 def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
     if len(prompts.shape) == 1:
@@ -118,53 +138,49 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     if prompts.shape[1] == 0:
         return None
+    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    mx.eval(expanded_prompts)
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
-    for idx in range(batch_size):
-        current_prompt = expanded_prompts[idx:idx+1]
-        mx.eval(current_prompt)
-
-        current_tokens = []
-        try:
+    try:
+        for idx in range(batch_size):
+            current_tokens = []
+
             if is_training:
-                # Initialize with prompt
-                current_input = current_prompt[0]
-                mx.eval(current_input)
-
+                current_input = expanded_prompts[idx]
                 while len(current_tokens) < max_tokens:
-                    # Generate one token at a time
-                    logits = model(current_input[None])
-                    next_token = mx.random.categorical(logits[:, -1, :])
+                    logits = model(current_input[None])[:, -1]
+                    next_token = mx.argmax(logits, axis=-1)
                     token = next_token.item()
                     current_tokens.append(token)
                     tokens_generated += 1
-                    # Clear intermediate results
-                    mx.eval(next_token)
-                    del logits
-
                    if token == tokenizer.eos_token_id:
                         break
-                    # Update input for next iteration
-                    current_input = mx.array([token])
-                    mx.eval(current_input)
+                    if (len(current_tokens) >= len(end_sequence) and
+                        mx.array_equal(
+                            mx.array(current_tokens[-len(end_sequence):]),
+                            end_sequence
+                        )):
+                        break
 
-                    # Clear cache periodically
-                    if len(current_tokens) % 8 == 0:
+                    current_input = mx.concatenate([current_input, mx.array([token])])
+
+                    if len(current_tokens) % 32 == 0:
+                        mx.eval(current_input)
                         mx.metal.clear_cache()
             else:
                 generator = generate_step(
-                    current_prompt[0],
+                    expanded_prompts[idx],
                     model,
                     max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x)
+                    sampler=lambda x: mx.argmax(x, axis=-1)
                 )
 
                 for token, _ in generator:
@@ -174,28 +190,18 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                         break
 
             if current_tokens:
-                token_array = mx.array(current_tokens)
-                mx.eval(token_array)
-                results.append(token_array)
-                del token_array
-
-        except Exception as e:
-            print(f"Generation failed for sequence {idx}: {e}")
-            continue
+                results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()
 
-    mx.metal.clear_cache()
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+        return results
 
-    if not results:
-        print("No successful generations")
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
         return None
-
-    mx.eval(results)
-
-    generation_time = time.perf_counter() - start_time
-    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-
-    return results
-

 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
@@ -243,15 +249,23 @@ def grpo_loss(
         prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
 
         try:
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                True
-            )
-
+            if is_validation:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size
+                )
+            else:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size,
+                    is_training=True
+                )
             if completions is not None:
                 for completion_ids in completions:
                     completion_text = tokenizer.decode(completion_ids.tolist())
@@ -261,8 +275,6 @@ def grpo_loss(
         except Exception as e:
            print(f"Generation error: {e}")
             continue
-
-    mx.metal.clear_cache()
 
     expanded_answers = []
     expanded_prompts = []
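
Notes on the patch (illustrative sketches, not part of the diff):

The reworked r1_soft_format_reward_func replaces the regex test with explicit tag positions, so a completion only earns 0.5 when all four tags are present, strictly ordered, and both blocks are non-empty. Below is a minimal standalone sketch of that check, assuming the <reasoning>/<answer> tag set used in the trainer; the helper name soft_format_score is hypothetical and exists only for illustration:

    def soft_format_score(completion: str) -> float:
        # Locate the four format tags; find() returns -1 when a tag is missing.
        reason_start = completion.find("<reasoning>")
        reason_end = completion.find("</reasoning>")
        answer_start = completion.find("<answer>")
        answer_end = completion.find("</answer>")
        # Require presence and strict ordering:
        # <reasoning> ... </reasoning> ... <answer> ... </answer>
        if -1 in (reason_start, reason_end, answer_start, answer_end):
            return 0.0
        if not (reason_start < reason_end < answer_start < answer_end):
            return 0.0
        # Both blocks must carry non-whitespace content.
        reason = completion[reason_start + len("<reasoning>"):reason_end].strip()
        answer = completion[answer_start + len("<answer>"):answer_end].strip()
        return 0.5 if reason and answer else 0.0

    # Well-formed output scores 0.5; swapped or empty blocks score 0.0.
    assert soft_format_score("<reasoning>2+2=4</reasoning><answer>4</answer>") == 0.5
    assert soft_format_score("<answer>4</answer><reasoning>2+2=4</reasoning>") == 0.0
    assert soft_format_score("<reasoning> </reasoning><answer>4</answer>") == 0.0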
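In the training branch of generate_grpo, decoding stops either on the tokenizer's EOS id or when the tail of current_tokens matches the token ids of "</answer>". A sketch of that suffix test under the same assumptions (ends_with_sequence is a hypothetical helper; in the patch the comparison is inlined in the while loop):

    import mlx.core as mx

    def ends_with_sequence(current_tokens: list[int], end_sequence: mx.array) -> bool:
        # Too few tokens generated to contain the terminator yet.
        if len(current_tokens) < len(end_sequence):
            return False
        # Compare the trailing window against the encoded "</answer>" ids.
        tail = mx.array(current_tokens[-len(end_sequence):])
        return mx.array_equal(tail, end_sequence).item()

The check runs once per generated token, so its cost is O(len(end_sequence)) per step. One known caveat of matching at the token-id level is that the same terminator text can tokenize differently depending on what precedes it, which is presumably why the plain EOS check is kept alongside it.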
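Finally, the grpo_loss change branches on is_validation only to toggle the is_training flag passed to generate_grpo; the two calls are otherwise identical. An equivalent, more compact formulation would be a single call (a suggestion, not what the patch does):

    completions = generate_grpo(
        model,
        prompt_tensor,
        max_tokens,
        tokenizer,
        group_size,
        is_training=not is_validation,
    )

Keeping the explicit if/else does make the validation path easier to spot when scanning the function, which may well be the intent here.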