mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
updates
This commit is contained in:
parent
88ca747e9e
commit
e96afe9e9f
@ -41,7 +41,7 @@ class GRPODataset:
|
|||||||
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||||
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
|
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
|
||||||
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
|
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
|
||||||
User: {prompt_str}. Assistant: """)
|
User: {prompt_str} Assistant: """)
|
||||||
else:
|
else:
|
||||||
prompt_tokens = tokenizer.encode(prompt_str)
|
prompt_tokens = tokenizer.encode(prompt_str)
|
||||||
answer_tokens = tokenizer.encode(answer_str)
|
answer_tokens = tokenizer.encode(answer_str)
|
||||||
|
@ -447,8 +447,8 @@ def evaluate_grpo(
|
|||||||
mx.eval(all_losses, ntokens)
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
# Aggregate across distributed workers
|
# Aggregate across distributed workers
|
||||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
all_losses = mx.distributed.all_sum(all_losses, stream=mx.gpu)
|
||||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
ntokens = mx.distributed.all_sum(ntokens, stream=mx.gpu)
|
||||||
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
|
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
|
||||||
|
|
||||||
# Calculate averages
|
# Calculate averages
|
||||||
|
Loading…
Reference in New Issue
Block a user