Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 02:33:23 +08:00)

Commit 0a19522ec4 ("updates"), parent 35a2d99cf9
@@ -45,44 +45,48 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
+
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions or not answer: # Handle empty inputs
+    if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
+
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
+
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
+
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
 
     scores = []
     for text in completions:
-        if not text: # Handle None or empty text
+        if not text:
             scores.append(0.0)
             continue
 
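As an aside, the soft and strict format rewards above differ only in their regular expressions. Below is a small self-contained check of the two patterns, useful for seeing which completion layouts each one accepts; the sample strings are invented for illustration and are not taken from the commit.

```python
import re

# The format patterns used by the soft and strict reward functions above.
SOFT_PATTERN = r"<think>.*?</think>\s*<answer>.*?</answer>"
STRICT_PATTERN = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"

# Sample completions, made up for illustration.
samples = {
    "newline formatted": "<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>\n",
    "single line": "<think>2 + 2 = 4</think> <answer>4</answer>",
    "missing answer tag": "<think>2 + 2 = 4</think> the answer is 4",
}

for name, text in samples.items():
    soft = bool(re.search(SOFT_PATTERN, text))
    strict = bool(re.search(STRICT_PATTERN, text))
    # Mirrors the reward functions: a match earns 0.5, otherwise 0.0.
    print(f"{name}: soft={0.5 if soft else 0.0}, strict={0.5 if strict else 0.0}")
```

The strict pattern is anchored with ^ and $ and requires newline-separated tags, so it only rewards completions that follow the exact template.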
@@ -137,11 +141,9 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         output[current_length] = token_value
         current_length += 1
 
-        # Check for EOS token
         if token_value == tokenizer.eos_token_id:
             break
 
-        # Check for "</answer>" sequence
         if current_length >= end_sequence_length:
             last_tokens = output[current_length - end_sequence_length:current_length].tolist()
             if last_tokens == end_sequence:
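The generation loop above stops either on the EOS token or when the most recent tokens spell out the closing answer tag. Here is a minimal sketch of that tail comparison on plain Python lists; all token ids below are made up for illustration.

```python
# Toy token ids; in the trainer, end_sequence presumably comes from tokenizing "</answer>".
end_sequence = [27, 9399, 29]
end_sequence_length = len(end_sequence)

generated = [101, 5, 42, 27, 9399, 29]  # pretend ids produced so far
current_length = len(generated)

# Same tail check as in generate_grpo: once enough tokens exist, compare the
# last len(end_sequence) ids against the marker and stop on a match.
if current_length >= end_sequence_length:
    last_tokens = generated[current_length - end_sequence_length:current_length]
    if last_tokens == end_sequence:
        print("stop: end marker generated")
```

The comparison runs on token ids rather than decoded text, which avoids re-decoding the output on every step.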
@@ -255,7 +257,6 @@ def grpo_loss(
     mx.eval(token_log_probs)
     mx.metal.clear_cache()
 
-
     # Reference policy probabilities
     if ref_model is not None:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
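The mx.eval followed by mx.metal.clear_cache above is a common MLX idiom: force the lazy log-prob computation to materialize, then return cached GPU buffers, which helps keep peak memory down before the reference-model pass. A tiny standalone illustration of the pattern (shapes are arbitrary, not the trainer's):

```python
import mlx.core as mx

x = mx.random.normal((512, 4096))
y = mx.softmax(x, axis=-1).sum(axis=-1)  # lazy graph; nothing has run yet

mx.eval(y)               # force evaluation so intermediate buffers can be released
mx.metal.clear_cache()   # hand cached Metal memory back to the allocator
```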
@@ -305,11 +306,12 @@ def grpo_loss(
     policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
+
     # Compute per-token loss following GRPO formula
-    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
 
     # Average over tokens and sequences
     sequence_sums = per_token_loss.sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
 
     loss = (sequence_sums / sequence_lengths).mean()
 
     # Calculate mean KL divergence for metrics
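For reference, here is a standalone numeric sketch of the reduction in this hunk, with toy shapes and placeholder values for token_log_probs, advantages, and kl_div (none of these numbers come from the trainer). It shows how the length mask turns per-token losses into a per-sequence mean before the batch average.

```python
import mlx.core as mx

B, T = 2, 5                                   # toy batch: 2 sequences, 5 token slots
token_log_probs = mx.zeros((B, T))            # placeholder per-token log-probs
advantages = mx.array([1.0, -0.5])            # one group-relative advantage per sequence
kl_div = mx.zeros((B, T))                     # placeholder per-token KL values
beta = 0.1
length_mask = mx.array([[1, 1, 1, 0, 0],
                        [1, 1, 1, 1, 1]], dtype=mx.float32)

# Same reduction as grpo_loss: average over valid tokens per sequence, then over the batch.
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
loss = (per_token_loss.sum(axis=1) / length_mask.sum(axis=1)).mean()
print(loss)  # scalar
```

Because the ratio is exp(x - stop_gradient(x)), it evaluates to 1 while still carrying the policy gradient, which matches the single-update GRPO formulation.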