mlx-examples (mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-07-10 03:21:13 +08:00)

commit 235348c211 (parent 79de353530)

    generation speed improvement in training too, from 3 t/s to 15 t/s
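
Summary of the change, as read from the diff below: during training, generate_grpo now builds a KV cache with cache.make_prompt_cache and feeds only the newest token through the model on each decode step, instead of rerunning the model over the whole growing sequence; that is what the 3 t/s to 15 t/s figure in the commit message refers to. Alongside that, the r1_strict_format_reward_func and r1_count_xml format rewards are relaxed so they no longer require exact newlines around the tags, generate_grpo gains an end_token parameter (default "</answer>"), and validation logging additionally prints the reference answer.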
@@ -13,6 +13,7 @@ import numpy as np
 
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
 from ..utils import generate_step
+from ..models import cache
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
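
The added import brings in the package's cache module; cache.make_prompt_cache(model) is used further down to build the KV cache for the training-time generation path.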
@@ -96,15 +97,13 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
         if reason_content and answer_content:
             scores.append(0.5)
             continue
 
         scores.append(0.0)
 
     return scores
 
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
+    pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
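
Review note on the new pattern: the lazy quantifier in <answer>*? binds to the literal > character rather than to the answer body, so a completion with text between <answer> and </answer> no longer matches, and r1_strict_format_reward_func scores it 0.0. A quick check; the "intended" pattern is my assumption of what was meant, not part of the commit:

    import re

    new_pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
    intended = r"<think>\n.*?\n</think>\n<answer>.*?</answer>"  # assumption: '.' was meant before '*?'
    ok = "<think>\nreasoning\n</think>\n<answer>\n42\n</answer>\n"

    print(bool(re.search(new_pattern, ok)))          # False -> reward stays 0.0
    print(bool(re.search(intended, ok, re.DOTALL)))  # True
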
@@ -120,28 +119,31 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
         count = 0.0
         if text.count("<think>\n") == 1:
             count += 0.125
-        if text.count("\n</think>\n") == 1:
+        if text.count("</think>") == 1:
             count += 0.125
-        if text.count("\n<answer>\n") == 1:
+        if text.count("<answer>") == 1:
             count += 0.125
-        if text.count("\n</answer>\n") == 1:
+        if text.count("</answer>") == 1:
             count += 0.125
-        end_text = text.split("\n</answer>\n")[-1]
+        end_text = text.split("</answer>")[-1]
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "</answer>"):
+    if model.training == False:
+        print("Model is in training mode", model.training, "Manually setting to eval mode")
+        model.train()
 
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
 
-    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode("</answer>"))
+    end_sequence = mx.array(tokenizer.encode(end_token))
 
     results = []
     tokens_generated = 0
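
The end_token argument is encoded once into end_sequence, and the decode loop in the next hunks stops when the tail of the generated tokens matches it; the comparison itself sits in a part of the condition this diff does not show. A sketch of that check under that assumption:

    import mlx.core as mx

    def ends_with(current_tokens: list[int], end_sequence: mx.array) -> bool:
        # True once the most recent tokens spell out the encoded end_token.
        if len(current_tokens) < len(end_sequence):
            return False
        tail = mx.array(current_tokens[-len(end_sequence):])
        return bool(mx.all(tail == end_sequence).item())

Separately, the added guard at the top of generate_grpo prints "Manually setting to eval mode" but then calls model.train(); the message and the call look inconsistent as committed.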
@@ -153,13 +155,15 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
 
         if is_training:
             current_input = expanded_prompts[idx]
+            prompt_cache = cache.make_prompt_cache(model)
 
+            # Initial forward pass with the prompt
+            logits = model(current_input[None], cache=prompt_cache)[:, -1]
 
             while len(current_tokens) < max_tokens:
-                logits = model(current_input[None])[:, -1]
                 probs = nn.softmax(logits, axis=-1)
                 next_token = mx.argmax(probs, axis=-1)
                 token = next_token.item()
-                current_tokens.append(token)
-                tokens_generated += 1
                 if token == tokenizer.eos_token_id:
                     break
                 if (len(current_tokens) >= len(end_sequence) and
@@ -168,10 +172,13 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
                     end_sequence
                 )):
                     break
-                current_input = mx.concatenate([current_input, mx.array([token])])
-                if len(current_tokens) % 32 == 0:
-                    mx.eval(current_input)
-                    mx.metal.clear_cache()
+                current_tokens.append(token)
+                tokens_generated += 1
+                current_input = mx.array([token])
+                logits = model(current_input[None], cache=prompt_cache)[:, -1]
+                mx.eval(current_input)
+                mx.metal.clear_cache()
         else:
             generator = generate_step(
                 expanded_prompts[idx],
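
These two hunks are the core of the speed-up: instead of rerunning the model over the full, growing current_input on every step (with periodic eval/clear to keep memory in check), the training path now does one forward pass over the prompt to fill a KV cache and then feeds only the single newest token per step. A minimal sketch of that pattern, assuming mlx_lm's cache helpers and placeholder model / eos_token_id names:

    import mlx.core as mx
    from mlx_lm.models import cache

    def greedy_decode_with_cache(model, prompt_tokens, max_tokens, eos_token_id):
        prompt_cache = cache.make_prompt_cache(model)
        # One pass over the whole prompt fills the per-layer KV cache.
        logits = model(prompt_tokens[None], cache=prompt_cache)[:, -1]
        tokens = []
        while len(tokens) < max_tokens:
            token = mx.argmax(logits, axis=-1).item()
            if token == eos_token_id:
                break
            tokens.append(token)
            # Later passes feed only the newest token; the cache supplies the context.
            logits = model(mx.array([token])[None], cache=prompt_cache)[:, -1]
        return tokens
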
@@ -180,10 +187,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
                 sampler=lambda x: mx.argmax(x, axis=-1)
             )
             for token, _ in generator:
-                current_tokens.append(token)
-                tokens_generated += 1
                 if token == tokenizer.eos_token_id:
                     break
+                current_tokens.append(token)
+                tokens_generated += 1
 
         if current_tokens:
             results.append(mx.array(current_tokens))
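
In the non-training path the change is bookkeeping only: current_tokens.append(token) and tokens_generated += 1 now run after the eos check, so the eos token itself is no longer appended to the completion.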
@@ -401,7 +408,7 @@ def grpo_loss(
 
     if is_validation:
         print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
+        print(f"Validation sample answer:\n{answer_text[-1]}\n")
     mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics