mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)

updates

parent 88ca747e9e
commit e96afe9e9f
@@ -41,7 +41,7 @@ class GRPODataset:
 prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
-User: {prompt_str}. Assistant: """)
+User: {prompt_str} Assistant: """)
 else:
 prompt_tokens = tokenizer.encode(prompt_str)
 answer_tokens = tokenizer.encode(answer_str)
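For context, a minimal sketch of the branch this hunk touches, reconstructed from the visible context lines and showing the new state after the change. The guarding condition is not visible in the diff, so `use_r1_template` below is a hypothetical stand-in, and the indentation is illustrative rather than the file's exact layout.

# Minimal sketch, not the exact file contents: `use_r1_template` is a
# hypothetical name for the condition that guards this branch; `tokenizer`,
# `prompt_str`, and `answer_str` are the names visible in the hunk.
if use_r1_template:
    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str} Assistant: """)  # trailing period after {prompt_str} removed by this commit
else:
    prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)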
@@ -447,8 +447,8 @@ def evaluate_grpo(
 mx.eval(all_losses, ntokens)

 # Aggregate across distributed workers
-all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
-ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
+ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
 all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

 # Calculate averages
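This hunk only switches the reduction stream for the distributed sums from mx.cpu to mx.gpu. A minimal sketch of how the aggregated values are typically turned into averages follows; the body under "# Calculate averages" is not shown in the diff, so the token-weighted averaging and the helper name aggregate_eval_stats below are assumptions for illustration, not the repository's exact code.

import mlx.core as mx

def aggregate_eval_stats(all_losses, ntokens, all_metrics):
    # Sum per-worker loss, token count, and metrics across all distributed
    # workers; after this commit the reductions run on the GPU stream.
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

    # Assumed averaging step: token-weighted mean loss and per-token metrics.
    avg_loss = (all_losses / ntokens).item()
    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    return avg_loss, avg_metrics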