This commit is contained in:
Goekdeniz-Guelmez 2025-02-11 09:09:28 +01:00
parent 88ca747e9e
commit e96afe9e9f
2 changed files with 3 additions and 3 deletions

View File

@@ -41,7 +41,7 @@ class GRPODataset:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer. The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str}. Assistant: """) User: {prompt_str} Assistant: """)
else: else:
prompt_tokens = tokenizer.encode(prompt_str) prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str) answer_tokens = tokenizer.encode(answer_str)

View File

@@ -447,8 +447,8 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens) mx.eval(all_losses, ntokens)
# Aggregate across distributed workers # Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages # Calculate averages