mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
updates
This commit is contained in:
parent
3ad6405298
commit
595125ad4e
@ -13,67 +13,92 @@ import numpy as np
|
|||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
from trainer import TrainingArgs, TrainingCallback, grad_checkpoint
|
||||||
|
|
||||||
def grad_checkpoint(layer):
|
|
||||||
"""
|
|
||||||
Update all instances of type(layer) to use gradient checkpointing.
|
|
||||||
"""
|
|
||||||
fn = type(layer).__call__
|
|
||||||
|
|
||||||
def checkpointed_fn(model, *args, **kwargs):
|
|
||||||
def inner_fn(params, *args, **kwargs):
|
|
||||||
model.update(params)
|
|
||||||
return fn(model, *args, **kwargs)
|
|
||||||
|
|
||||||
return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
|
def compute_ppo_loss(
|
||||||
|
new_logprobs: mx.array,
|
||||||
type(layer).__call__ = checkpointed_fn
|
old_logprobs: mx.array,
|
||||||
|
values: mx.array,
|
||||||
|
old_values: mx.array,
|
||||||
|
advantages: mx.array,
|
||||||
|
returns: mx.array,
|
||||||
|
padding_mask: mx.array,
|
||||||
|
padding_mask_p1: mx.array = None,
|
||||||
|
vf_coef: float = 0.5,
|
||||||
|
cliprange: float = 0.2,
|
||||||
|
cliprange_value: float = 0.2
|
||||||
|
) -> tuple[mx.array, mx.array, mx.array]:
|
||||||
|
"""Compute PPO loss with policy and value components and masking"""
|
||||||
|
padding_mask_p1 = padding_mask_p1 if padding_mask_p1 is not None else padding_mask
|
||||||
|
|
||||||
|
# Value loss
|
||||||
|
vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value)
|
||||||
|
vf_losses = mx.maximum(
|
||||||
|
mx.square(values - returns),
|
||||||
|
mx.square(vpred_clipped - returns)
|
||||||
|
)
|
||||||
|
vf_loss = 0.5 * mx.mean(mx.where(~padding_mask_p1, vf_losses, 0))
|
||||||
|
|
||||||
|
# Policy loss
|
||||||
|
ratio = mx.exp(new_logprobs - old_logprobs)
|
||||||
|
pg_losses = mx.maximum(
|
||||||
|
-advantages * ratio,
|
||||||
|
-advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
|
||||||
|
)
|
||||||
|
pg_loss = mx.mean(mx.where(~padding_mask, pg_losses, 0))
|
||||||
|
|
||||||
|
total_loss = pg_loss + vf_coef * vf_loss
|
||||||
|
return total_loss, pg_loss, vf_loss
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArgs:
|
class PPOTrainingArgs(TrainingArgs):
|
||||||
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
vf_coef: float = field(default=0.5, metadata={"help": "Value function coefficient"})
|
||||||
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
|
cliprange: float = field(default=0.2, metadata={"help": "Policy gradient clipping range"})
|
||||||
val_batches: int = field(
|
cliprange_value: float = field(default=0.2, metadata={"help": "Value function clipping range"})
|
||||||
default=25,
|
|
||||||
metadata={
|
|
||||||
"help": "Number of validation batches, -1 uses the entire validation set."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
steps_per_report: int = field(
|
|
||||||
default=10,
|
|
||||||
metadata={"help": "Number of training steps between loss reporting."},
|
|
||||||
)
|
|
||||||
steps_per_eval: int = field(
|
|
||||||
default=200, metadata={"help": "Number of training steps between validations."}
|
|
||||||
)
|
|
||||||
steps_per_save: int = field(
|
|
||||||
default=100, metadata={"help": "Save the model every number steps"}
|
|
||||||
)
|
|
||||||
max_seq_length: int = field(
|
|
||||||
default=2048, metadata={"help": "Maximum sequence length."}
|
|
||||||
)
|
|
||||||
adapter_file: str = field(
|
|
||||||
default="adapters.safetensors",
|
|
||||||
metadata={"help": "Save/load path for the trained adapter weights."},
|
|
||||||
)
|
|
||||||
grad_checkpoint: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_loss(model, inputs, targets, lengths):
|
def ppo_loss(
|
||||||
logits = model(inputs)
|
model,
|
||||||
logits = logits.astype(mx.float32)
|
inputs,
|
||||||
|
targets,
|
||||||
|
lengths,
|
||||||
|
old_logprobs,
|
||||||
|
values,
|
||||||
|
old_values,
|
||||||
|
advantages,
|
||||||
|
returns,
|
||||||
|
vf_coef=0.5,
|
||||||
|
cliprange=0.2,
|
||||||
|
cliprange_value=0.2
|
||||||
|
):
|
||||||
|
# Get new logits and create length mask
|
||||||
|
logits = model(inputs).astype(mx.float32)
|
||||||
|
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||||
|
|
||||||
|
# Get new log probs
|
||||||
|
new_logprobs = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||||
|
ntoks = length_mask.sum()
|
||||||
|
new_logprobs = new_logprobs.sum() / ntoks
|
||||||
|
|
||||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
# Value loss with clipping
|
||||||
|
vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value)
|
||||||
|
vf_loss = 0.5 * mx.maximum(
|
||||||
|
mx.square(values - returns),
|
||||||
|
mx.square(vpred_clipped - returns)
|
||||||
|
).mean()
|
||||||
|
|
||||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
# Policy loss with clipping
|
||||||
ntoks = length_mask.sum()
|
ratio = mx.exp(new_logprobs - old_logprobs)
|
||||||
ce = ce.sum() / ntoks
|
pg_loss = mx.maximum(
|
||||||
|
-advantages * ratio,
|
||||||
|
-advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
|
||||||
|
).mean()
|
||||||
|
|
||||||
return ce, ntoks
|
total_loss = pg_loss + vf_coef * vf_loss
|
||||||
|
return total_loss, pg_loss, vf_loss, ntoks
|
||||||
|
|
||||||
|
|
||||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||||
@ -131,49 +156,63 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
model,
|
model,
|
||||||
dataset,
|
dataset,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
batch_size,
|
batch_size,
|
||||||
num_batches,
|
num_batches,
|
||||||
max_seq_length=2048,
|
max_seq_length=2048,
|
||||||
loss: callable = default_loss,
|
old_logprobs=None,
|
||||||
iterate_batches: callable = iterate_batches,
|
values=None,
|
||||||
|
old_values=None,
|
||||||
|
advantages=None,
|
||||||
|
returns=None,
|
||||||
|
vf_coef=0.5,
|
||||||
|
cliprange=0.2,
|
||||||
|
cliprange_value=0.2,
|
||||||
|
loss: callable = compute_ppo_loss,
|
||||||
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = 0
|
total_loss = 0
|
||||||
ntokens = 0
|
total_pg_loss = 0
|
||||||
|
total_vf_loss = 0
|
||||||
|
ntokens = 0
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
for _, batch in zip(
|
for _, batch in zip(
|
||||||
index_iterator,
|
index_iterator,
|
||||||
iterate_batches(
|
iterate_batches(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_seq_length=max_seq_length,
|
max_seq_length=max_seq_length,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
losses, toks = loss(model, *batch)
|
losses, pg_loss, vf_loss, toks = loss(
|
||||||
all_losses += losses * toks
|
model, *batch,
|
||||||
ntokens += toks
|
old_logprobs=old_logprobs,
|
||||||
mx.eval(all_losses, ntokens)
|
values=values,
|
||||||
|
old_values=old_values,
|
||||||
|
advantages=advantages,
|
||||||
|
returns=returns,
|
||||||
|
vf_coef=vf_coef,
|
||||||
|
cliprange=cliprange,
|
||||||
|
cliprange_value=cliprange_value
|
||||||
|
)
|
||||||
|
|
||||||
|
total_loss += losses * toks
|
||||||
|
total_pg_loss += pg_loss * toks
|
||||||
|
total_vf_loss += vf_loss * toks
|
||||||
|
ntokens += toks
|
||||||
|
mx.eval(total_loss, total_pg_loss, total_vf_loss, ntokens)
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
|
||||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
total_pg_loss = mx.distributed.all_sum(total_pg_loss, stream=mx.cpu)
|
||||||
|
total_vf_loss = mx.distributed.all_sum(total_vf_loss, stream=mx.cpu)
|
||||||
|
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
return (total_loss / ntokens).item(), (total_pg_loss / ntokens).item(), (total_vf_loss / ntokens).item()
|
||||||
|
|
||||||
|
|
||||||
class TrainingCallback:
|
|
||||||
|
|
||||||
def on_train_loss_report(self, train_info: dict):
|
|
||||||
"""Called to report training loss at specified intervals."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_val_loss_report(self, val_info: dict):
|
|
||||||
"""Called to report validation loss at specified intervals or the beginning."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -183,7 +222,7 @@ def train(
|
|||||||
train_dataset,
|
train_dataset,
|
||||||
val_dataset,
|
val_dataset,
|
||||||
args: TrainingArgs = TrainingArgs(),
|
args: TrainingArgs = TrainingArgs(),
|
||||||
loss: callable = default_loss,
|
loss: callable = ppo_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user