mirror of https://github.com/ml-explore/mlx-examples.git

commit 6c58aa995c ("updates")
parent 93370ff1c3
@@ -44,7 +44,7 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def compute_rewards(sequences, batch_size, group_size):
+def compute_default_rewards(sequences, batch_size, group_size):
     """
     Args:
         sequences: List of word sequences
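For context: since grpo_loss later reshapes the rewards to (batch_size, group_size), a default reward helper must return one score per completion. A minimal sketch of a callable matching the renamed signature above — the length-based scoring rule is invented for illustration, not the repository's implementation:

import mlx.core as mx

# Hypothetical stand-in with the same signature as compute_default_rewards;
# the scoring rule is illustrative only.
def toy_default_rewards(sequences, batch_size, group_size):
    assert len(sequences) == batch_size * group_size
    # One reward per completion, so the caller can reshape to (batch_size, group_size).
    return mx.array([float(len(seq)) for seq in sequences])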
@@ -72,6 +72,7 @@ def grpo_loss(
     model,
     tokenizer,
     prompts,
+    reward_funcs=None,
     beta=0.1,
     group_size=4,
     epslion=1e-4,
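The new reward_funcs=None keyword lets callers inject custom scoring functions in place of the default. Judging from how the next hunk consumes them (each rf(all_completions) is summed across functions), a reward function should map the list of decoded completions to per-completion scores. A hedged sketch, with an invented heuristic:

import mlx.core as mx

# Assumed contract, inferred from the diff: list of completions in,
# one score per completion out, in a form mx.array / sum can combine.
def contains_answer_reward(completions):
    # Invented heuristic: reward completions that state an explicit answer.
    return mx.array([1.0 if "Answer:" in c else 0.0 for c in completions])

# Then, per the new signature:
# grpo_loss(model, tokenizer, prompts, reward_funcs=[contains_answer_reward], ...)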
@@ -134,7 +135,10 @@ def grpo_loss(
     kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
 
     # Calculate rewards
-    rewards = compute_rewards(all_completions, batch_size, group_size)
+    if reward_funcs:
+        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
+    else:
+        rewards = compute_default_rewards(all_completions, batch_size, group_size)
 
     # Compute grouped-wise rewards
     grouped_rewards = rewards.reshape(batch_size, group_size)
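For reference, the kl_div context line above is the low-variance "k3" estimator of the per-token KL penalty commonly used in GRPO: writing the per-token log-probabilities of the reference and current policies as $\log \pi_{\text{ref}}$ and $\log \pi_\theta$,

$$\mathrm{kl\_div} = e^{\,\log \pi_{\text{ref}} - \log \pi_\theta} - \left(\log \pi_{\text{ref}} - \log \pi_\theta\right) - 1 \;\approx\; \mathrm{KL}\!\left(\pi_\theta \,\|\, \pi_{\text{ref}}\right)$$

The expression is nonnegative and vanishes exactly when the two per-token probabilities agree, which keeps the penalty stable even for noisy single-sample estimates.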
@@ -266,6 +270,59 @@ def evaluate_grpo(
 
     return (all_losses / ntokens).item()
 
+def evaluate_grpo(
+    model,
+    ref_model,
+    dataset,
+    tokenizer,
+    batch_size,
+    num_batches,
+    beta: float,
+    epslion: float,
+    group_size: int,
+    max_seq_length,
+    reward_funcs=None,
+    loss: callable = grpo_loss,
+    iterate_batches: callable = iterate_batches
+):
+    all_losses = 0
+    ntokens = 0
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+    for _, batch in zip(
+        index_iterator,
+        iterate_batches(
+            dataset=dataset,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_seq_length=max_seq_length,
+        ),
+    ):
+        # Extract prompts from the batch (assuming the batch contains 'prompts')
+        prompts = batch.get("prompts", None)
+
+        # Call the loss function with the correct arguments
+        losses, toks, metrics = loss(
+            model=model,
+            tokenizer=tokenizer,
+            prompts=prompts,
+            reward_funcs=reward_funcs,
+            beta=beta,
+            group_size=group_size,
+            epslion=epslion,
+            ref_model=ref_model
+        )
+
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
+
+    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+
+    return (all_losses / ntokens).item()
+
 
 def train(
     model,
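A minimal sketch of how the newly added evaluate_grpo might be invoked, mirroring the keywords in the hunk above; model, ref_model, val_set, and tokenizer are placeholders assumed to be loaded elsewhere, and the hyperparameter values are illustrative:

# Hypothetical call site; names and values are assumptions, not repository code.
val_loss = evaluate_grpo(
    model=model,
    ref_model=ref_model,
    dataset=val_set,
    tokenizer=tokenizer,
    batch_size=4,
    num_batches=-1,    # -1 iterates the entire validation set
    beta=0.1,
    epslion=1e-4,      # parameter name as spelled in the source
    group_size=4,
    max_seq_length=512,
    reward_funcs=None,  # or a list of callables such as [contains_answer_reward]
)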