mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 21:01:13 +08:00
grpo_trainer shoudl be done
This commit is contained in:
parent
6c58aa995c
commit
80bcf68956
@ -240,6 +240,7 @@ def evaluate_grpo(
|
|||||||
epslion: float,
|
epslion: float,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
max_seq_length,
|
max_seq_length,
|
||||||
|
reward_funcs = None,
|
||||||
loss: callable = grpo_loss,
|
loss: callable = grpo_loss,
|
||||||
iterate_batches: callable = iterate_batches
|
iterate_batches: callable = iterate_batches
|
||||||
):
|
):
|
||||||
@ -257,52 +258,7 @@ def evaluate_grpo(
|
|||||||
max_seq_length=max_seq_length,
|
max_seq_length=max_seq_length,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
losses, toks = loss(
|
prompts = batch
|
||||||
model,
|
|
||||||
*batch
|
|
||||||
)
|
|
||||||
all_losses += losses * toks
|
|
||||||
ntokens += toks
|
|
||||||
mx.eval(all_losses, ntokens)
|
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
|
||||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
|
||||||
|
|
||||||
def evaluate_grpo(
|
|
||||||
model,
|
|
||||||
ref_model,
|
|
||||||
dataset,
|
|
||||||
tokenizer,
|
|
||||||
batch_size,
|
|
||||||
num_batches,
|
|
||||||
beta: float,
|
|
||||||
epslion: float,
|
|
||||||
group_size: int,
|
|
||||||
max_seq_length,
|
|
||||||
reward_funcs=None,
|
|
||||||
loss: callable = grpo_loss,
|
|
||||||
iterate_batches: callable = iterate_batches
|
|
||||||
):
|
|
||||||
all_losses = 0
|
|
||||||
ntokens = 0
|
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
|
||||||
|
|
||||||
for _, batch in zip(
|
|
||||||
index_iterator,
|
|
||||||
iterate_batches(
|
|
||||||
dataset=dataset,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
batch_size=batch_size,
|
|
||||||
max_seq_length=max_seq_length,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
# Extract prompts from the batch (assuming the batch contains 'prompts')
|
|
||||||
prompts = batch.get("prompts", None)
|
|
||||||
|
|
||||||
# Call the loss function with the correct arguments
|
|
||||||
losses, toks, metrics = loss(
|
losses, toks, metrics = loss(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -313,15 +269,25 @@ def evaluate_grpo(
|
|||||||
epslion=epslion,
|
epslion=epslion,
|
||||||
ref_model=ref_model
|
ref_model=ref_model
|
||||||
)
|
)
|
||||||
|
|
||||||
all_losses += losses * toks
|
all_losses += losses * toks
|
||||||
ntokens += toks
|
ntokens += toks
|
||||||
|
|
||||||
|
if all_metrics is None:
|
||||||
|
all_metrics = {k: v * toks for k, v in metrics.items()}
|
||||||
|
else:
|
||||||
|
for k, v in metrics.items():
|
||||||
|
all_metrics[k] += v * toks
|
||||||
|
|
||||||
mx.eval(all_losses, ntokens)
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||||
|
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
|
||||||
|
avg_loss = (all_losses / ntokens).item()
|
||||||
|
|
||||||
|
return avg_loss, ntokens, avg_metrics
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -335,7 +301,7 @@ def train(
|
|||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
print(f"Starting training..., iters: {args.iters}")
|
print(f"Starting GRPO training..., iters: {args.iters}")
|
||||||
world = mx.distributed.init()
|
world = mx.distributed.init()
|
||||||
world_size = world.size()
|
world_size = world.size()
|
||||||
rank = world.rank()
|
rank = world.rank()
|
||||||
@ -349,7 +315,7 @@ def train(
|
|||||||
|
|
||||||
def step(batch):
|
def step(batch):
|
||||||
# Forward and backward pass
|
# Forward and backward pass
|
||||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
(loss, toks, metrics), grad = loss_value_and_grad(model, *batch)
|
||||||
|
|
||||||
# All reduce the gradients if running in distributed mode
|
# All reduce the gradients if running in distributed mode
|
||||||
grad = average_gradients(grad)
|
grad = average_gradients(grad)
|
||||||
@ -357,18 +323,22 @@ def train(
|
|||||||
# Model update
|
# Model update
|
||||||
optimizer.update(model, grad)
|
optimizer.update(model, grad)
|
||||||
|
|
||||||
return lvalue, toks
|
return loss, toks, metrics
|
||||||
|
|
||||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||||
|
|
||||||
# Save initial model weights as reference
|
|
||||||
ref_weights = {k: v.copy() for k, v in model.parameters().items()}
|
|
||||||
|
|
||||||
losses = 0
|
losses = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
steps = 0
|
steps = 0
|
||||||
trained_tokens = 0
|
trained_tokens = 0
|
||||||
# Main training loop
|
accumulated_metrics = {
|
||||||
|
'rewards': 0,
|
||||||
|
'rewards_std': 0,
|
||||||
|
'grouped_rewards': 0,
|
||||||
|
'grouped_rewards_std': 0,
|
||||||
|
'kl': 0
|
||||||
|
}
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
for it, batch in zip(
|
for it, batch in zip(
|
||||||
range(1, args.iters + 1),
|
range(1, args.iters + 1),
|
||||||
@ -384,7 +354,7 @@ def train(
|
|||||||
# is always measured before any training.
|
# is always measured before any training.
|
||||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
val_loss = evaluate(
|
val_loss, val_ntokens, val_metrics = evaluate(
|
||||||
model=model,
|
model=model,
|
||||||
dataset=val_dataset,
|
dataset=val_dataset,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
@ -398,61 +368,74 @@ def train(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(
|
print(
|
||||||
f"Iter {it}: "
|
f"Iter {it}: "
|
||||||
f"Val loss {val_loss:.3f}, "
|
f"Val loss {val_loss:.8f}, "
|
||||||
|
f"Val rewards {val_metrics['rewards']:.3f}, "
|
||||||
|
f"Val rewards_std {val_metrics['rewards_std']:.3f}, "
|
||||||
|
f"Val grouped_rewards {val_metrics['grouped_rewards']:.3f}, "
|
||||||
|
f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
|
||||||
|
f"Val kl {val_metrics['kl']:.3f}, "
|
||||||
f"Val took {val_time:.3f}s",
|
f"Val took {val_time:.3f}s",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if training_callback is not None:
|
if training_callback is not None:
|
||||||
val_info = {
|
training_callback.on_val_loss_report({
|
||||||
"iteration": it,
|
"iteration": it,
|
||||||
"val_loss": val_loss,
|
"val_loss": val_loss,
|
||||||
|
**{f"val_{k}": v for k, v in val_metrics.items()},
|
||||||
"val_time": val_time,
|
"val_time": val_time,
|
||||||
}
|
})
|
||||||
training_callback.on_val_loss_report(val_info)
|
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
lvalue, toks = step(batch)
|
loss, toks, metrics = step(batch)
|
||||||
losses += lvalue
|
losses += loss
|
||||||
n_tokens += toks
|
n_tokens += toks
|
||||||
steps += 1
|
steps += 1
|
||||||
|
for k, v in metrics.items():
|
||||||
|
accumulated_metrics[k] += v
|
||||||
mx.eval(state, losses, n_tokens)
|
mx.eval(state, losses, n_tokens)
|
||||||
|
|
||||||
# Report training loss if needed
|
|
||||||
if it % args.steps_per_report == 0 or it == args.iters:
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
|
|
||||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||||
train_loss /= steps * mx.distributed.init().size()
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
|
avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
|
||||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||||
learning_rate = optimizer.learning_rate.item()
|
learning_rate = optimizer.learning_rate.item()
|
||||||
it_sec = args.steps_per_report / (stop - start)
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / (stop - start)
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
trained_tokens += n_tokens
|
trained_tokens += n_tokens
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(
|
print(
|
||||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
f"Iter {it}: Train loss {train_loss:.8f}, "
|
||||||
|
f"Rewards {avg_metrics['rewards']:.3f}, "
|
||||||
|
f"Rewards_std {avg_metrics['rewards_std']:.3f}, "
|
||||||
|
f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
|
||||||
|
f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
|
||||||
|
f"Grouped Rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
|
||||||
|
f"KL {val_metrics['kl']:.3f}, "
|
||||||
f"Learning Rate {learning_rate:.3e}, "
|
f"Learning Rate {learning_rate:.3e}, "
|
||||||
f"It/sec {it_sec:.3f}, "
|
f"It/sec {it_sec:.3f}, "
|
||||||
f"Tokens/sec {tokens_sec:.3f}, "
|
f"Tokens/sec {tokens_sec:.3f}, "
|
||||||
f"Trained Tokens {trained_tokens}, "
|
|
||||||
f"Peak mem {peak_mem:.3f} GB",
|
f"Peak mem {peak_mem:.3f} GB",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if training_callback is not None:
|
if training_callback is not None:
|
||||||
train_info = {
|
training_callback.on_train_loss_report({
|
||||||
"iteration": it,
|
"iteration": it,
|
||||||
"train_loss": train_loss,
|
"train_loss": train_loss,
|
||||||
|
**{f"train_{k}": v for k, v in avg_metrics.items()},
|
||||||
"learning_rate": learning_rate,
|
"learning_rate": learning_rate,
|
||||||
"iterations_per_second": it_sec,
|
"iterations_per_second": it_sec,
|
||||||
"tokens_per_second": tokens_sec,
|
"tokens_per_second": tokens_sec,
|
||||||
"trained_tokens": trained_tokens,
|
"trained_tokens": trained_tokens,
|
||||||
"peak_memory": peak_mem,
|
"peak_memory": peak_mem,
|
||||||
}
|
})
|
||||||
training_callback.on_train_loss_report(train_info)
|
|
||||||
|
|
||||||
losses = 0
|
losses = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user