mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
adding custom system message integration in dataset, more opimizations (generates now faster, while same RAM usage), fix for the identical generatrions, seperated the reward functions into a seperate file.
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from typing import List, Optional, Callable
|
||||
from typing import List, Optional, Tuple, Generator, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
import time
|
||||
import re
|
||||
|
||||
from mlx.utils import tree_flatten
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
|
||||
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
|
||||
from ..utils import generate_step
|
||||
from ..models import cache
|
||||
@@ -50,88 +50,16 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
)
|
||||
|
||||
|
||||
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
|
||||
|
||||
|
||||
def r1_extract_xml_answer(text: str) -> str:
|
||||
try:
|
||||
answer = text.split("<answer>")[-1]
|
||||
answer = answer.split("</answer>")[0]
|
||||
return answer.strip()
|
||||
except:
|
||||
print("r1_extract_xml_answer returned empty string")
|
||||
return ""
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions or not answer:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||
|
||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
scores = []
|
||||
for completion in completions:
|
||||
if not completion:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
|
||||
reason_start = completion.find("<think>")
|
||||
reason_end = completion.find("</think>")
|
||||
answer_start = completion.find("<answer>")
|
||||
answer_end = completion.find("</answer>")
|
||||
|
||||
if (reason_start != -1 and reason_end != -1 and
|
||||
answer_start != -1 and answer_end != -1 and
|
||||
reason_start < reason_end < answer_start < answer_end):
|
||||
reason_content = completion[reason_start+13:reason_end].strip()
|
||||
answer_content = completion[answer_start+8:answer_end].strip()
|
||||
if reason_content and answer_content:
|
||||
scores.append(0.5)
|
||||
continue
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
|
||||
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
|
||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
scores = []
|
||||
for text in completions:
|
||||
if not text:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
count = 0.0
|
||||
if text.count("<think>\n") == 1:
|
||||
count += 0.125
|
||||
if text.count("</think>") == 1:
|
||||
count += 0.125
|
||||
if text.count("<answer>") == 1:
|
||||
count += 0.125
|
||||
if text.count("</answer>") == 1:
|
||||
count += 0.125
|
||||
end_text = text.split("</answer>")[-1]
|
||||
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||
scores.append(max(0.0, count))
|
||||
return scores
|
||||
|
||||
|
||||
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>", temperature: float = 0.8):
|
||||
def generate_grpo(
|
||||
model: nn.Module,
|
||||
prompts,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
group_size,
|
||||
is_training=False,
|
||||
end_token: str = "</answer>",
|
||||
temperature: float = 0.8
|
||||
):
|
||||
if len(prompts.shape) == 1:
|
||||
prompts = prompts[None, :]
|
||||
if prompts.shape[1] == 0:
|
||||
@@ -144,19 +72,38 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
|
||||
try:
|
||||
for idx in range(batch_size):
|
||||
current_tokens = []
|
||||
|
||||
if is_training:
|
||||
current_input = expanded_prompts[idx]
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
logits = model(current_input[None], cache=prompt_cache)[:, -1]
|
||||
mx.eval(logits, prompt_cache)
|
||||
|
||||
while len(current_tokens) < max_tokens:
|
||||
logits_temp = logits / temperature
|
||||
probs = nn.softmax(logits_temp, axis=-1)
|
||||
next_token = mx.argmax(probs, axis=-1)
|
||||
next_token = mx.random.categorical(logits_temp)
|
||||
token = next_token.item()
|
||||
|
||||
test_sequence = current_tokens + [token]
|
||||
if (len(test_sequence) >= len(end_sequence) and
|
||||
mx.array_equal(
|
||||
mx.array(test_sequence[-len(end_sequence):]),
|
||||
end_sequence
|
||||
)):
|
||||
current_tokens.append(token)
|
||||
break
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
current_tokens.append(token)
|
||||
current_input = mx.array([token])
|
||||
logits = model(current_input[None], cache=prompt_cache)[:, -1]
|
||||
mx.eval(current_input, logits, probs, next_token, token)
|
||||
else:
|
||||
generator = generate_step(
|
||||
expanded_prompts[idx],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
sampler=lambda x: mx.random.categorical(x / temperature)
|
||||
)
|
||||
for token, _ in generator:
|
||||
test_sequence = current_tokens + [token]
|
||||
if (len(test_sequence) >= len(end_sequence) and
|
||||
mx.array_equal(
|
||||
@@ -168,27 +115,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
|
||||
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
current_tokens.append(token)
|
||||
current_input = mx.array([token])
|
||||
logits = model(current_input[None], cache=prompt_cache)[:, -1]
|
||||
mx.eval(current_input, logits, probs, next_token)
|
||||
else:
|
||||
generator = generate_step(
|
||||
expanded_prompts[idx],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
sampler=lambda x: mx.argmax(x, axis=-1)
|
||||
)
|
||||
for token, _ in generator:
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
current_tokens.append(token)
|
||||
|
||||
if current_tokens:
|
||||
results.append(mx.array(current_tokens))
|
||||
mx.metal.clear_cache()
|
||||
|
||||
mx.eval(results)
|
||||
return results
|
||||
|
||||
@@ -207,12 +137,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
|
||||
seq_logits = logits[i, :seq_len]
|
||||
seq_targets = targets[i, :seq_len]
|
||||
log_probs = nn.log_softmax(seq_logits, axis=-1)
|
||||
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
seq_targets.reshape(seq_len, 1), axis=-1
|
||||
).squeeze(-1)
|
||||
|
||||
per_token_logps.append(token_log_probs)
|
||||
mx.eval(logits)
|
||||
return per_token_logps
|
||||
@@ -399,7 +327,7 @@ def grpo_loss(
|
||||
}
|
||||
|
||||
if is_validation:
|
||||
print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
|
||||
print(f"\nValidation sample generation:\n{all_completion_texts}\n")
|
||||
print(f"Validation sample answer:\n{answer_text[-1]}\n")
|
||||
mx.metal.clear_cache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user