mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
adding custom system message integration in dataset, more opimizations (generates now faster, while same RAM usage), fix for the identical generatrions, seperated the reward functions into a seperate file.
This commit is contained in:
parent
bd5f081ca5
commit
e4eac9c97b
@ -19,6 +19,7 @@ class GRPODataset:
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str = "prompt",
|
||||
answer_key: str = "answer",
|
||||
system_key: str = "system",
|
||||
use_chat_template: bool = False,
|
||||
use_prompt: bool = False
|
||||
):
|
||||
@ -27,9 +28,11 @@ class GRPODataset:
|
||||
prompt_str = str(item[prompt_key])
|
||||
answer_str = str(item[answer_key])
|
||||
if use_chat_template:
|
||||
default_system_str = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."
|
||||
system_str = item.get(system_key, default_system_str)
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
[
|
||||
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
|
||||
{'role': 'system', 'content': system_str},
|
||||
{'role': 'user', 'content': prompt_str}
|
||||
],
|
||||
add_generation_prompt=True
|
||||
|
82
llms/mlx_lm/tuner/grpo_reward_functions.py
Normal file
82
llms/mlx_lm/tuner/grpo_reward_functions.py
Normal file
@ -0,0 +1,82 @@
|
||||
from typing import List, Optional, Callable
|
||||
import re
|
||||
|
||||
|
||||
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
|
||||
|
||||
|
||||
def r1_extract_xml_answer(text: str) -> str:
|
||||
try:
|
||||
answer = text.split("<answer>")[-1]
|
||||
answer = answer.split("</answer>")[0]
|
||||
return answer.strip()
|
||||
except:
|
||||
print("r1_extract_xml_answer returned empty string")
|
||||
return ""
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions or not answer:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||
|
||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
scores = []
|
||||
for completion in completions:
|
||||
if not completion:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
|
||||
reason_start = completion.find("<think>")
|
||||
reason_end = completion.find("</think>")
|
||||
answer_start = completion.find("<answer>")
|
||||
answer_end = completion.find("</answer>")
|
||||
|
||||
if (reason_start != -1 and reason_end != -1 and
|
||||
answer_start != -1 and answer_end != -1 and
|
||||
reason_start < reason_end < answer_start < answer_end):
|
||||
reason_content = completion[reason_start+13:reason_end].strip()
|
||||
answer_content = completion[answer_start+8:answer_end].strip()
|
||||
if reason_content and answer_content:
|
||||
scores.append(0.5)
|
||||
continue
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
|
||||
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
scores = []
|
||||
for text in completions:
|
||||
if not text:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
count = 0.0
|
||||
if text.count("<think>\n") == 1:
|
||||
count += 0.125
|
||||
if text.count("</think>") == 1:
|
||||
count += 0.125
|
||||
if text.count("<answer>") == 1:
|
||||
count += 0.125
|
||||
if text.count("</answer>") == 1:
|
||||
count += 0.125
|
||||
end_text = text.split("</answer>")[-1]
|
||||
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||
scores.append(max(0.0, count))
|
||||
return scores
|
@ -1,16 +1,16 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from typing import List, Optional, Callable
|
||||
from typing import List, Optional, Tuple, Generator, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
import time
|
||||
import re
|
||||
|
||||
from mlx.utils import tree_flatten
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
|
||||
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
|
||||
from ..utils import generate_step
|
||||
from ..models import cache
|
||||
@ -50,88 +50,16 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
)
|
||||
|
||||
|
||||
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
|
||||
|
||||
|
||||
def r1_extract_xml_answer(text: str) -> str:
|
||||
try:
|
||||
answer = text.split("<answer>")[-1]
|
||||
answer = answer.split("</answer>")[0]
|
||||
return answer.strip()
|
||||
except:
|
||||
print("r1_extract_xml_answer returned empty string")
|
||||
return ""
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions or not answer:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||
|
||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
scores = []
|
||||
for completion in completions:
|
||||
if not completion:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
|
||||
reason_start = completion.find("<think>")
|
||||
reason_end = completion.find("</think>")
|
||||
answer_start = completion.find("<answer>")
|
||||
answer_end = completion.find("</answer>")
|
||||
|
||||
if (reason_start != -1 and reason_end != -1 and
|
||||
answer_start != -1 and answer_end != -1 and
|
||||
reason_start < reason_end < answer_start < answer_end):
|
||||
reason_content = completion[reason_start+13:reason_end].strip()
|
||||
answer_content = completion[answer_start+8:answer_end].strip()
|
||||
if reason_content and answer_content:
|
||||
scores.append(0.5)
|
||||
continue
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
|
||||
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
|
||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
scores = []
|
||||
for text in completions:
|
||||
if not text:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
count = 0.0
|
||||
if text.count("<think>\n") == 1:
|
||||
count += 0.125
|
||||
if text.count("</think>") == 1:
|
||||
count += 0.125
|
||||
if text.count("<answer>") == 1:
|
||||
count += 0.125
|
||||
if text.count("</answer>") == 1:
|
||||
count += 0.125
|
||||
end_text = text.split("</answer>")[-1]
|
||||
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||
scores.append(max(0.0, count))
|
||||
return scores
|
||||
|
||||
|
||||
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
|
||||
def generate_grpo(
|
||||
model: nn.Module,
|
||||
prompts,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
group_size,
|
||||
is_training=False,
|
||||
end_token: str = "</answer>",
|
||||
temperature: float = 0.8
|
||||
):
|
||||
if len(prompts.shape) == 1:
|
||||
prompts = prompts[None, :]
|
||||
if prompts.shape[1] == 0:
|
||||
@ -144,19 +72,38 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
|
||||
try:
|
||||
for idx in range(batch_size):
|
||||
current_tokens = []
|
||||
|
||||
if is_training:
|
||||
current_input = expanded_prompts[idx]
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
logits = model(current_input[None], cache=prompt_cache)[:, -1]
|
||||
mx.eval(logits, prompt_cache)
|
||||
|
||||
while len(current_tokens) < max_tokens:
|
||||
logits_temp = logits / temperature
|
||||
probs = nn.softmax(logits_temp, axis=-1)
|
||||
next_token = mx.argmax(probs, axis=-1)
|
||||
next_token = mx.random.categorical(logits_temp)
|
||||
token = next_token.item()
|
||||
|
||||
test_sequence = current_tokens + [token]
|
||||
if (len(test_sequence) >= len(end_sequence) and
|
||||
mx.array_equal(
|
||||
mx.array(test_sequence[-len(end_sequence):]),
|
||||
end_sequence
|
||||
)):
|
||||
current_tokens.append(token)
|
||||
break
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
current_tokens.append(token)
|
||||
current_input = mx.array([token])
|
||||
logits = model(current_input[None], cache=prompt_cache)[:, -1]
|
||||
mx.eval(current_input, logits, probs, next_token, token)
|
||||
else:
|
||||
generator = generate_step(
|
||||
expanded_prompts[idx],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
sampler=lambda x: mx.random.categorical(x / temperature)
|
||||
)
|
||||
for token, _ in generator:
|
||||
test_sequence = current_tokens + [token]
|
||||
if (len(test_sequence) >= len(end_sequence) and
|
||||
mx.array_equal(
|
||||
@ -168,27 +115,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
|
||||
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
current_tokens.append(token)
|
||||
current_input = mx.array([token])
|
||||
logits = model(current_input[None], cache=prompt_cache)[:, -1]
|
||||
mx.eval(current_input, logits, probs, next_token)
|
||||
else:
|
||||
generator = generate_step(
|
||||
expanded_prompts[idx],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
sampler=lambda x: mx.argmax(x, axis=-1)
|
||||
)
|
||||
for token, _ in generator:
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
current_tokens.append(token)
|
||||
|
||||
if current_tokens:
|
||||
results.append(mx.array(current_tokens))
|
||||
mx.metal.clear_cache()
|
||||
|
||||
mx.eval(results)
|
||||
return results
|
||||
|
||||
@ -207,12 +137,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
|
||||
seq_logits = logits[i, :seq_len]
|
||||
seq_targets = targets[i, :seq_len]
|
||||
log_probs = nn.log_softmax(seq_logits, axis=-1)
|
||||
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
seq_targets.reshape(seq_len, 1), axis=-1
|
||||
).squeeze(-1)
|
||||
|
||||
per_token_logps.append(token_log_probs)
|
||||
mx.eval(logits)
|
||||
return per_token_logps
|
||||
@ -399,7 +327,7 @@ def grpo_loss(
|
||||
}
|
||||
|
||||
if is_validation:
|
||||
print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
|
||||
print(f"\nValidation sample generation:\n{all_completion_texts}\n")
|
||||
print(f"Validation sample answer:\n{answer_text[-1]}\n")
|
||||
mx.metal.clear_cache()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user