Mirror of https://github.com/ml-explore/mlx-examples.git
updates
parent d84ad0cf86
commit a33cad84b4
@@ -206,15 +206,15 @@ def build_parser():
     )
     parser.add_argument(
         "--use-chat-template",
-        type=bool,
+        action="store_true",
         help="If the model is a Chat model, use the Chat template.",
-        default=False,
+        default=None,
     )
     parser.add_argument(
         "--use-prompt",
-        type=bool,
-        help="Rather to use the prompt from teh R1 paper.",
-        default=False,
+        action="store_true",
+        help="Rather to use the prompt from the R1 paper.",
+        default=None,
     )
     return parser

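Note on the hunk above: `type=bool` is a classic argparse trap, since argparse calls `bool()` on the raw string and any non-empty string (including `"False"`) is truthy, so the flag could never actually be turned off. `action="store_true"` is the idiomatic presence flag, and `default=None` (rather than `False`) presumably lets downstream code tell "flag omitted" apart from "flag set". A minimal standalone sketch of the difference (not repo code):

```python
import argparse

# Buggy pattern removed by this commit: bool("False") is True, so the
# option is effectively always on once any value is supplied.
buggy = argparse.ArgumentParser()
buggy.add_argument("--use-chat-template", type=bool, default=False)
print(buggy.parse_args(["--use-chat-template", "False"]).use_chat_template)  # True

# Fixed pattern: a plain presence flag.
fixed = argparse.ArgumentParser()
fixed.add_argument("--use-chat-template", action="store_true", default=None)
print(fixed.parse_args([]).use_chat_template)                       # None
print(fixed.parse_args(["--use-chat-template"]).use_chat_template)  # True
```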
@@ -12,7 +12,6 @@ from mlx.utils import tree_flatten

 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

-
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
     group_size: int = field(
@@ -35,7 +34,6 @@ class GRPOTrainingArgs(TrainingArgs):
         }
     )

-
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -46,62 +44,30 @@ def r1_extract_xml_answer(text: str) -> str:
         print("[extract_xml_answer] Failed to extract answer from: ", text)
         return ""

-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Calculates reward based on accuracy of extracted answers.
-    Args:
-        prompts: List of input prompts
-        completions: List of completion strings
-        answer: Expected answer or list of answers
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: Reward values for each completion
-    """
-    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards numerical responses.
-    Args:
-        prompts: List of input prompts
-        completions: List of completion strings
-        answer: Expected answer or list of answers
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: Reward values for each completion
-    """
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Rewards completions with flexible XML format."""
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
     matches = [re.match(pattern, r) for r in completions]
     return [0.5 if match else 0.0 for match in matches]

+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards completions with strict XML format.
-    Args:
-        prompts: List of input prompts
-        completions: List of completion strings
-        answer: Expected answer or list of answers
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: Reward values for each completion
-    """
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
     matches = [re.match(pattern, r) for r in completions]
     return [0.5 if match else 0.0 for match in matches]

 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Calculates score based on XML formatting.
-    Args:
-        prompts: List of input prompts (unused)
-        completions: List of completion strings to evaluate
-        answer: Expected answer or list of answers (unused)
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: List of scores based on XML tag presence and formatting
-    """
     scores = []
     for text in completions:
         count = 0.0
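The reward functions in this hunk are pure functions of the completion strings, so they are easy to exercise in isolation. A hypothetical usage sketch (the sample strings are mine, and it assumes `r1_extract_xml_answer` returns the text between `<answer>` tags):

```python
completions = [
    "<think>2 + 2 = 4</think> <answer>4</answer>",
    "The answer is four.",
]
answers = ["4", "4"]

# Assuming r1_extract_xml_answer pulls the text between <answer> tags:
print(r1_accuracy_reward_func([], completions, answers))     # [2.0, 0.0]
print(r1_int_reward_func([], completions, answers))          # [0.5, 0.0]
print(r1_soft_format_reward_func([], completions, answers))  # [0.5, 0.0]
# r1_strict_format_reward_func would score both 0.0 here: its pattern
# requires the newline-delimited <think>/<answer> layout.
```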
@@ -116,10 +82,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
             count += 0.125
             count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
         scores.append(count)
     return scores

-
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
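The signature change above drops the implicit `temperature=1.0`, so every call site must now choose the sampling temperature explicitly. For orientation, a sketch of where the parameter typically enters generation; the body of `generate_grpo` is outside this hunk, so this is an assumption, not the repo's code:

```python
import mlx.core as mx

def sample_token(logits: mx.array, temperature: float) -> mx.array:
    # Temperature rescales logits before categorical sampling: higher values
    # flatten the distribution, lower values sharpen it.
    return mx.random.categorical(logits / temperature)
```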
@@ -172,30 +137,24 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):


 def get_per_token_logps(model, inputs, lengths):
-    logits = model(inputs).astype(mx.float16) # [batch_size, seq_len, vocab_size]
-    # Remove last position as it corresponds to the next token prediction
-    logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
-    targets = inputs[:, 1:] # Shift inputs to get targets
+    logits = model(inputs).astype(mx.float16)
+    logits = logits[:, :-1, :]
+    targets = inputs[:, 1:]

-    # Process sequences individually to save memory
     per_token_logps = []
     for i in range(logits.shape[0]):
-        # Get sequence length for this example
-        seq_len = int(lengths[i]) - 1 # -1 because we removed last position
+        seq_len = int(lengths[i]) - 1

-        # Get logits and targets for this sequence
-        seq_logits = logits[i, :seq_len] # [seq_len, vocab_size]
-        seq_targets = targets[i, :seq_len] # [seq_len]
+        seq_logits = logits[i, :seq_len]
+        seq_targets = targets[i, :seq_len]

-        # Compute log probabilities
-        log_probs = nn.log_softmax(seq_logits, axis=-1) # [seq_len, vocab_size]
+        log_probs = nn.log_softmax(seq_logits, axis=-1)

-        # Gather log probs for actual tokens
         token_log_probs = mx.take_along_axis(
             log_probs,
             seq_targets.reshape(seq_len, 1),
             axis=-1
-        ).squeeze(-1) # [seq_len]
+        ).squeeze(-1)

         per_token_logps.append(token_log_probs)
     mx.eval(logits)
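The comments stripped in this hunk documented the shift-and-gather pattern for per-token log-probabilities. A toy worked example of the same pattern (toy shapes and values, not repo code):

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.zeros((1, 3, 4))   # batch=1, seq_len=3, vocab=4; uniform logits
inputs = mx.array([[1, 2, 3]])

logits = logits[:, :-1, :]  # position t predicts token t+1, so drop the last
targets = inputs[:, 1:]     # tokens 2..T are the prediction targets

log_probs = nn.log_softmax(logits[0], axis=-1)
token_log_probs = mx.take_along_axis(
    log_probs, targets[0].reshape(-1, 1), axis=-1
).squeeze(-1)
print(token_log_probs)  # [-1.386, -1.386]: log(1/4) for each gathered target
```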
@@ -316,7 +275,7 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

     # Compute KL divergence using Schulman's approximator
-    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
+    kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)

     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
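Both the old and the new `kl_div` lines are algebraic forms of Schulman's k3 estimator, k3(x) = exp(x) - x - 1 >= 0, applied to opposite signs of the log-ratio: the old line uses log(pi_ref / pi_theta), the new one log(pi_theta / pi_ref). A small reading aid (not repo code):

```python
import mlx.core as mx

def k3(x: mx.array) -> mx.array:
    # Schulman's KL approximator: non-negative, zero only at x == 0.
    return mx.exp(x) - x - 1

log_ratio = mx.array([0.3, -0.2, 0.0])  # token_log_probs - ref_token_log_probs
old_kl = k3(-log_ratio)  # mx.exp(ref - logp) - (ref - logp) - 1
new_kl = k3(log_ratio)   # (mx.exp(logp - ref) - 1) - (logp - ref)
print(old_kl, new_kl)    # both elementwise >= 0; they agree only at 0.0
```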
@@ -325,10 +284,10 @@
     policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

     # Compute per-token loss following GRPO formula
-    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
+    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask

     # Average over tokens and sequences
-    sequence_sums = (per_token_loss * length_mask).sum(axis=1)
+    sequence_sums = per_token_loss.sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()

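The two changes in this hunk move the length mask from the summation into `per_token_loss` itself. The final loss is numerically unchanged, since padded positions contribute zero either way and `sequence_lengths` still counts only valid tokens; the difference is that the masked tensor is now safe for any later reuse, such as logging. A toy check (not repo code):

```python
import mlx.core as mx

core_loss = mx.array([[1.0, 2.0, 3.0]])    # hypothetical unmasked per-token loss
length_mask = mx.array([[1.0, 1.0, 0.0]])  # last position is padding

old_sums = (core_loss * length_mask).sum(axis=1)  # old: mask applied at sum time
masked_loss = core_loss * length_mask             # new: mask applied at creation
new_sums = masked_loss.sum(axis=1)
print(old_sums, new_sums)  # both [3.0]: 1 + 2, padding excluded
```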