From e4eac9c97b68b0cbf006fd314769fce723bedfe0 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 24 Feb 2025 20:49:11 +0100 Subject: [PATCH] adding custom system message integration in dataset, more opimizations (generates now faster, while same RAM usage), fix for the identical generatrions, seperated the reward functions into a seperate file. --- llms/mlx_lm/tuner/datasets.py | 5 +- llms/mlx_lm/tuner/grpo_reward_functions.py | 82 ++++++++++++ llms/mlx_lm/tuner/grpo_trainer.py | 144 ++++++--------------- 3 files changed, 122 insertions(+), 109 deletions(-) create mode 100644 llms/mlx_lm/tuner/grpo_reward_functions.py diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index abdd7c36..5bf17ef8 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -19,6 +19,7 @@ class GRPODataset: tokenizer: PreTrainedTokenizer, prompt_key: str = "prompt", answer_key: str = "answer", + system_key: str = "system", use_chat_template: bool = False, use_prompt: bool = False ): @@ -27,9 +28,11 @@ class GRPODataset: prompt_str = str(item[prompt_key]) answer_str = str(item[answer_key]) if use_chat_template: + default_system_str = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ." + system_str = item.get(system_key, default_system_str) prompt_tokens = tokenizer.apply_chat_template( [ - {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""}, + {'role': 'system', 'content': system_str}, {'role': 'user', 'content': prompt_str} ], add_generation_prompt=True diff --git a/llms/mlx_lm/tuner/grpo_reward_functions.py b/llms/mlx_lm/tuner/grpo_reward_functions.py new file mode 100644 index 00000000..59dfbfef --- /dev/null +++ b/llms/mlx_lm/tuner/grpo_reward_functions.py @@ -0,0 +1,82 @@ +from typing import List, Optional, Callable +import re + + +RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]] + + +def r1_extract_xml_answer(text: str) -> str: + try: + answer = text.split("")[-1] + answer = answer.split("")[0] + return answer.strip() + except: + print("r1_extract_xml_answer returned empty string") + return "" + +def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + if not completions: + return [0.0] * len(prompts) + extracted_responses = [r1_extract_xml_answer(r) for r in completions] + return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses] + +def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + if not completions or not answer: + return [0.0] * len(prompts) + extracted_responses = [r1_extract_xml_answer(r) for r in completions] + return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)] + +def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + if not completions: + return [0.0] * len(prompts) + + scores = [] + for completion in completions: + if not completion: + scores.append(0.0) + continue + + reason_start = completion.find("") + reason_end = completion.find("") + answer_start = completion.find("") + answer_end = completion.find("") + + if (reason_start != -1 and reason_end != -1 and + answer_start != -1 and answer_end != -1 and + reason_start < reason_end < answer_start < answer_end): + reason_content = completion[reason_start+13:reason_end].strip() + answer_content = completion[answer_start+8:answer_end].strip() + if reason_content and answer_content: + scores.append(0.5) + continue + scores.append(0.0) + return scores + +def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + if not completions: + return [0.0] * len(prompts) + pattern = r"\n.*?\n\n*?" + matches = [bool(re.search(pattern, r)) if r else False for r in completions] + return [0.5 if match else 0.0 for match in matches] + +def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + if not completions: + return [0.0] * len(prompts) + scores = [] + for text in completions: + if not text: + scores.append(0.0) + continue + count = 0.0 + if text.count("\n") == 1: + count += 0.125 + if text.count("") == 1: + count += 0.125 + if text.count("") == 1: + count += 0.125 + if text.count("") == 1: + count += 0.125 + end_text = text.split("")[-1] + count -= len(end_text) * 0.001 if len(end_text) > 0 else 0 + scores.append(max(0.0, count)) + return scores \ No newline at end of file diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 9d9051f0..5ec3020a 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -1,16 +1,16 @@ # Copyright © 2024 Apple Inc. -from typing import List, Optional, Callable +from typing import List, Optional, Tuple, Generator, Callable, Any from dataclasses import dataclass, field from pathlib import Path import time -import re from mlx.utils import tree_flatten import mlx.core as mx import mlx.nn as nn import numpy as np +from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients from ..utils import generate_step from ..models import cache @@ -50,88 +50,16 @@ class GRPOTrainingArgs(TrainingArgs): ) -RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]] - - -def r1_extract_xml_answer(text: str) -> str: - try: - answer = text.split("")[-1] - answer = answer.split("")[0] - return answer.strip() - except: - print("r1_extract_xml_answer returned empty string") - return "" - -def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: - if not completions: - return [0.0] * len(prompts) - extracted_responses = [r1_extract_xml_answer(r) for r in completions] - return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses] - -def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: - if not completions or not answer: - return [0.0] * len(prompts) - extracted_responses = [r1_extract_xml_answer(r) for r in completions] - return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)] - -def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: - if not completions: - return [0.0] * len(prompts) - - scores = [] - for completion in completions: - if not completion: - scores.append(0.0) - continue - - reason_start = completion.find("") - reason_end = completion.find("") - answer_start = completion.find("") - answer_end = completion.find("") - - if (reason_start != -1 and reason_end != -1 and - answer_start != -1 and answer_end != -1 and - reason_start < reason_end < answer_start < answer_end): - reason_content = completion[reason_start+13:reason_end].strip() - answer_content = completion[answer_start+8:answer_end].strip() - if reason_content and answer_content: - scores.append(0.5) - continue - scores.append(0.0) - return scores - -def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: - if not completions: - return [0.0] * len(prompts) - pattern = r"\n.*?\n\n*?" - matches = [bool(re.search(pattern, r)) if r else False for r in completions] - return [0.5 if match else 0.0 for match in matches] - - -def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: - if not completions: - return [0.0] * len(prompts) - scores = [] - for text in completions: - if not text: - scores.append(0.0) - continue - count = 0.0 - if text.count("\n") == 1: - count += 0.125 - if text.count("") == 1: - count += 0.125 - if text.count("") == 1: - count += 0.125 - if text.count("") == 1: - count += 0.125 - end_text = text.split("")[-1] - count -= len(end_text) * 0.001 if len(end_text) > 0 else 0 - scores.append(max(0.0, count)) - return scores - - -def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "", temperature: float = 0.8): +def generate_grpo( + model: nn.Module, + prompts, + max_tokens, + tokenizer, + group_size, + is_training=False, + end_token: str = "", + temperature: float = 0.8 + ): if len(prompts.shape) == 1: prompts = prompts[None, :] if prompts.shape[1] == 0: @@ -144,19 +72,38 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, try: for idx in range(batch_size): current_tokens = [] - if is_training: current_input = expanded_prompts[idx] prompt_cache = cache.make_prompt_cache(model) logits = model(current_input[None], cache=prompt_cache)[:, -1] mx.eval(logits, prompt_cache) - while len(current_tokens) < max_tokens: logits_temp = logits / temperature probs = nn.softmax(logits_temp, axis=-1) - next_token = mx.argmax(probs, axis=-1) + next_token = mx.random.categorical(logits_temp) token = next_token.item() - + test_sequence = current_tokens + [token] + if (len(test_sequence) >= len(end_sequence) and + mx.array_equal( + mx.array(test_sequence[-len(end_sequence):]), + end_sequence + )): + current_tokens.append(token) + break + if token == tokenizer.eos_token_id: + break + current_tokens.append(token) + current_input = mx.array([token]) + logits = model(current_input[None], cache=prompt_cache)[:, -1] + mx.eval(current_input, logits, probs, next_token, token) + else: + generator = generate_step( + expanded_prompts[idx], + model, + max_tokens=max_tokens, + sampler=lambda x: mx.random.categorical(x / temperature) + ) + for token, _ in generator: test_sequence = current_tokens + [token] if (len(test_sequence) >= len(end_sequence) and mx.array_equal( @@ -168,27 +115,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, if token == tokenizer.eos_token_id: break - current_tokens.append(token) - current_input = mx.array([token]) - logits = model(current_input[None], cache=prompt_cache)[:, -1] - mx.eval(current_input, logits, probs, next_token) - else: - generator = generate_step( - expanded_prompts[idx], - model, - max_tokens=max_tokens, - sampler=lambda x: mx.argmax(x, axis=-1) - ) - for token, _ in generator: - if token == tokenizer.eos_token_id: - break - current_tokens.append(token) - if current_tokens: results.append(mx.array(current_tokens)) mx.metal.clear_cache() - mx.eval(results) return results @@ -207,12 +137,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths): seq_logits = logits[i, :seq_len] seq_targets = targets[i, :seq_len] log_probs = nn.log_softmax(seq_logits, axis=-1) - token_log_probs = mx.take_along_axis( log_probs, seq_targets.reshape(seq_len, 1), axis=-1 ).squeeze(-1) - per_token_logps.append(token_log_probs) mx.eval(logits) return per_token_logps @@ -399,7 +327,7 @@ def grpo_loss( } if is_validation: - print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n") + print(f"\nValidation sample generation:\n{all_completion_texts}\n") print(f"Validation sample answer:\n{answer_text[-1]}\n") mx.metal.clear_cache()