mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
clean up
This commit is contained in:
parent
0bc2a881ad
commit
e88f0fad4b
@ -121,9 +121,9 @@ def generate_grpo(
|
||||
prompt_tokens,
|
||||
max_tokens: int,
|
||||
group_size: int,
|
||||
end_token: str = "</answer>",
|
||||
temperature: float = 0.8,
|
||||
batch_size: int = 1,
|
||||
temperature: float,
|
||||
batch_size: int,
|
||||
end_token: str = "</answer>"
|
||||
):
|
||||
try:
|
||||
end_sequence = mx.array(tokenizer.encode(end_token))
|
||||
@ -239,7 +239,6 @@ def grpo_loss(
|
||||
|
||||
expanded_answers = []
|
||||
expanded_prompts = []
|
||||
|
||||
unique_prompt_indices = sorted(set(batch_indices))
|
||||
grouped_completions = {idx: [] for idx in unique_prompt_indices}
|
||||
|
||||
@ -262,7 +261,6 @@ def grpo_loss(
|
||||
all_completions = ordered_completions
|
||||
all_completion_texts = ordered_completion_texts
|
||||
batch_indices = ordered_batch_indices
|
||||
|
||||
max_length = max(ids.shape[0] for ids in all_completions)
|
||||
padded_completions = []
|
||||
attention_masks = []
|
||||
@ -617,11 +615,8 @@ def train_grpo(
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
# Extract prompt tokens from the batch
|
||||
prompt_tokens, targets, prompt_lens, target_lens = batch
|
||||
|
||||
# First, generate completions without gradient tracking
|
||||
# The model will be frozen during this call
|
||||
all_completions, all_completion_texts, batch_indices = generate_grpo(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@ -630,9 +625,7 @@ def train_grpo(
|
||||
group_size=args.group_size,
|
||||
temperature=args.temperature
|
||||
)
|
||||
|
||||
# Now calculate loss and gradients with pre-generated completions
|
||||
# We need to update loss_fn to accept these pre-generated completions
|
||||
|
||||
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||
model,
|
||||
tokenizer=tokenizer,
|
||||
|
Loading…
Reference in New Issue
Block a user