mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
clean up
This commit is contained in:
parent
0bc2a881ad
commit
e88f0fad4b
@ -121,9 +121,9 @@ def generate_grpo(
|
|||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
end_token: str = "</answer>",
|
temperature: float,
|
||||||
temperature: float = 0.8,
|
batch_size: int,
|
||||||
batch_size: int = 1,
|
end_token: str = "</answer>"
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
end_sequence = mx.array(tokenizer.encode(end_token))
|
end_sequence = mx.array(tokenizer.encode(end_token))
|
||||||
@ -239,7 +239,6 @@ def grpo_loss(
|
|||||||
|
|
||||||
expanded_answers = []
|
expanded_answers = []
|
||||||
expanded_prompts = []
|
expanded_prompts = []
|
||||||
|
|
||||||
unique_prompt_indices = sorted(set(batch_indices))
|
unique_prompt_indices = sorted(set(batch_indices))
|
||||||
grouped_completions = {idx: [] for idx in unique_prompt_indices}
|
grouped_completions = {idx: [] for idx in unique_prompt_indices}
|
||||||
|
|
||||||
@ -262,7 +261,6 @@ def grpo_loss(
|
|||||||
all_completions = ordered_completions
|
all_completions = ordered_completions
|
||||||
all_completion_texts = ordered_completion_texts
|
all_completion_texts = ordered_completion_texts
|
||||||
batch_indices = ordered_batch_indices
|
batch_indices = ordered_batch_indices
|
||||||
|
|
||||||
max_length = max(ids.shape[0] for ids in all_completions)
|
max_length = max(ids.shape[0] for ids in all_completions)
|
||||||
padded_completions = []
|
padded_completions = []
|
||||||
attention_masks = []
|
attention_masks = []
|
||||||
@ -617,11 +615,8 @@ def train_grpo(
|
|||||||
state = [model.state, optimizer.state]
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
def step(batch):
|
def step(batch):
|
||||||
# Extract prompt tokens from the batch
|
|
||||||
prompt_tokens, targets, prompt_lens, target_lens = batch
|
prompt_tokens, targets, prompt_lens, target_lens = batch
|
||||||
|
|
||||||
# First, generate completions without gradient tracking
|
|
||||||
# The model will be frozen during this call
|
|
||||||
all_completions, all_completion_texts, batch_indices = generate_grpo(
|
all_completions, all_completion_texts, batch_indices = generate_grpo(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -630,9 +625,7 @@ def train_grpo(
|
|||||||
group_size=args.group_size,
|
group_size=args.group_size,
|
||||||
temperature=args.temperature
|
temperature=args.temperature
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now calculate loss and gradients with pre-generated completions
|
|
||||||
# We need to update loss_fn to accept these pre-generated completions
|
|
||||||
(loss, toks, metrics), grad = loss_value_and_grad(
|
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||||
model,
|
model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
Loading…
Reference in New Issue
Block a user