diff --git a/llms/mlx_lm/tuner/ppo_trainer.py b/llms/mlx_lm/tuner/ppo_trainer.py
index 40dffe63..9c5f0a90 100644
--- a/llms/mlx_lm/tuner/ppo_trainer.py
+++ b/llms/mlx_lm/tuner/ppo_trainer.py
@@ -16,55 +16,77 @@ from mlx.utils import tree_flatten
 from trainer import TrainingArgs, TrainingCallback, grad_checkpoint
-
-def compute_ppo_loss(
-    new_logprobs: mx.array,
-    old_logprobs: mx.array,
-    values: mx.array,
-    old_values: mx.array,
-    advantages: mx.array,
-    returns: mx.array,
-    padding_mask: mx.array,
-    padding_mask_p1: mx.array = None,
-    vf_coef: float = 0.5,
-    cliprange: float = 0.2,
-    cliprange_value: float = 0.2
-) -> tuple[mx.array, mx.array, mx.array]:
-    """Compute PPO loss with policy and value components and masking"""
-    padding_mask_p1 = padding_mask_p1 if padding_mask_p1 is not None else padding_mask
-
-    # Value loss
-    vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value)
-    vf_losses = mx.maximum(
-        mx.square(values - returns),
-        mx.square(vpred_clipped - returns)
-    )
-    vf_loss = 0.5 * mx.mean(mx.where(~padding_mask_p1, vf_losses, 0))
-
-    # Policy loss
-    ratio = mx.exp(new_logprobs - old_logprobs)
-    pg_losses = mx.maximum(
-        -advantages * ratio,
-        -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
-    )
-    pg_loss = mx.mean(mx.where(~padding_mask, pg_losses, 0))
-
-    total_loss = pg_loss + vf_coef * vf_loss
-    return total_loss, pg_loss, vf_loss
-
-
 @dataclass
 class PPOTrainingArgs(TrainingArgs):
     vf_coef: float = field(default=0.5, metadata={"help": "Value function coefficient"})
     cliprange: float = field(default=0.2, metadata={"help": "Policy gradient clipping range"})
     cliprange_value: float = field(default=0.2, metadata={"help": "Value function clipping range"})
+    gamma: float = field(default=0.99, metadata={"help": "Discount factor"})
+    lambda_: float = field(default=0.95, metadata={"help": "GAE lambda"})
+
+def compute_returns(
+    rewards: mx.array,
+    gamma: float = 0.99
+) -> mx.array:
+    """Compute discounted returns over the time axis (rewards: (batch, seq_len))"""
+    returns = mx.zeros_like(rewards)
+    running_return = 0
+
+    for t in reversed(range(rewards.shape[1])):
+        running_return = rewards[:, t] + gamma * running_return
+        returns[:, t] = running_return
+
+    return returns
+
+def compute_advantages(
+    values: mx.array,
+    returns: mx.array,
+    rewards: mx.array,
+    gamma: float = 0.99,
+    lambda_: float = 0.95
+) -> mx.array:
+    """Compute advantages with GAE over the time axis"""
+    advantages = mx.zeros_like(returns)
+    running_advantage = 0
+
+    for t in reversed(range(returns.shape[1])):
+        if t < returns.shape[1] - 1:
+            delta = rewards[:, t] + gamma * values[:, t + 1] - values[:, t]
+        else:
+            delta = rewards[:, t] - values[:, t]
+
+        running_advantage = delta + gamma * lambda_ * running_advantage
+        advantages[:, t] = running_advantage
+
+    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+def make_predictions(model, x, mask):
+    inputs = x[:, :-1]
+    targets = x[:, 1:]
+
+    logits = model(inputs)
+    logits = logits.astype(mx.float32)
+
+    return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
+
+def compute_rewards(model, x, mask, reward_scale=1.0):
+    """
+    Compute rewards based on model predictions and actual targets.
+    Basic implementation using per-token log probabilities as rewards.
+    """
+    logits = model(x[:, :-1])
+    targets = x[:, 1:]
+
+    log_probs = -nn.losses.cross_entropy(logits, targets, reduction='none')
+    rewards = log_probs * mask[:, :-1] * reward_scale
+
+    return rewards
+
 def ppo_loss(
     model,
     inputs,
-    targets,
-    lengths,
+    mask,
     old_logprobs,
     values,
     old_values,
@@ -73,14 +95,10 @@ def ppo_loss(
     vf_coef=0.5,
     cliprange=0.2,
     cliprange_value=0.2
-):
-    # Get new logits and create length mask
-    logits = model(inputs).astype(mx.float32)
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
-
+):
     # Get new log probs
-    new_logprobs = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    new_logprobs = make_predictions(model, inputs, mask)
+    ntoks = mask[:, :-1].sum()
     new_logprobs = new_logprobs.sum() / ntoks
 
     # Value loss with clipping
@@ -101,58 +119,52 @@
     return total_loss, pg_loss, vf_loss, ntoks
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    # Sort by length:
-    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
+def iterate_ppo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+    # Sort by length
+    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+    if len(dataset) < batch_size:
+        raise ValueError(f"Dataset must have at least batch_size={batch_size} examples but only has {len(dataset)}.")
+
+    # Handle distributed training
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
+    # Make batches
+    batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)]
+
+    while True:
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
+            batch = [dataset[j] for j in batch_idx[i]]
+            lengths = [len(x) for x in batch]
+
+            # Handle sequence length
+            if max(lengths) > max_seq_length:
+                print(f"[WARNING] Truncating sequences longer than {max_seq_length}")
+
+            # Pad to multiple of 8
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
+            # Create batch array
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
+            mask = np.zeros((batch_size // step, max_length_in_batch), np.int32)
+
+            for j in range(batch_size // step):
+                truncated_length = min(lengths[j], max_seq_length)
+                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                mask[j, :truncated_length] = 1
+                lengths[j] = truncated_length
-
-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
-    step = mx.distributed.init().size()
-    if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-
-    # Make the batches:
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-
-    while True:
-        indices = np.random.permutation(len(batch_idx))
-        for i in indices:
-            batch = [dataset[j] for j in batch_idx[i]]
-            lengths = [len(x) for x in batch]
-            if max(lengths) > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
-                )
-
-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-
-            for j in range(batch_size // step):
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-            batch = mx.array(batch_arr)
-
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
-
-        if not train:
-            break
+            batch = mx.array(batch_arr)
+            mask = mx.array(mask)
+
+            yield batch, mask
+
+        if not train:
+            break
 
 
 def evaluate(
@@ -170,8 +182,8 @@
     vf_coef=0.5,
     cliprange=0.2,
     cliprange_value=0.2,
-    loss: callable = compute_ppo_loss,
-    iterate_batches: callable = iterate_batches,
+    loss: callable = ppo_loss,
+    iterate_ppo_batches: callable = iterate_ppo_batches,
 ):
     total_loss = 0
     total_pg_loss = 0
@@ -182,7 +194,7 @@
 
     for _, batch in zip(
         index_iterator,
-        iterate_batches(
+        iterate_ppo_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
@@ -221,12 +233,12 @@
     optimizer,
     train_dataset,
     val_dataset,
-    args: TrainingArgs = TrainingArgs(),
+    args: PPOTrainingArgs = PPOTrainingArgs(),
     loss: callable = ppo_loss,
-    iterate_batches: callable = iterate_batches,
+    iterate_ppo_batches: callable = iterate_ppo_batches,
     training_callback: TrainingCallback = None,
 ):
-    print(f"Starting training..., iters: {args.iters}")
+    print(f"Starting PPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
     rank = world.rank()
@@ -239,18 +251,38 @@
     state = [model.state, optimizer.state]
 
     def step(batch):
-        # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
-
-        # All reduce the gradients if running in distributed mode
+        x, mask = batch
+
+        # Initial forward pass (assumes the model exposes a value_head)
+        old_logprobs = make_predictions(model, x, mask)
+        values = model.value_head(x[:, :-1])
+        old_values = values
+
+        # Compute rewards (implement reward calculation based on your task)
+        rewards = compute_rewards(model, x, mask)
+
+        # Compute returns and advantages
+        returns = compute_returns(rewards, gamma=args.gamma)
+        advantages = compute_advantages(values, returns, rewards,
+                                        gamma=args.gamma,
+                                        lambda_=args.lambda_)
+
+        def loss_fn(model, x, mask):
+            total_loss, pg_loss, vf_loss, ntoks = ppo_loss(
+                model, x, mask,
+                old_logprobs, values, old_values,
+                advantages, returns,
+                vf_coef=args.vf_coef,
+                cliprange=args.cliprange,
+                cliprange_value=args.cliprange_value
+            )
+            return total_loss, ntoks, pg_loss, vf_loss
+
+        (loss_val, toks, pg_loss, vf_loss), grad = nn.value_and_grad(model, loss_fn)(model, x, mask)
         grad = average_gradients(grad)
-
-        # Model update
         optimizer.update(model, grad)
-
-        return lvalue, toks
-
-    loss_value_and_grad = nn.value_and_grad(model, loss)
+
+        return loss_val, toks, pg_loss, vf_loss
 
     losses = 0
     n_tokens = 0
@@ -260,7 +292,7 @@
     start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
-        iterate_batches(
+        iterate_ppo_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
@@ -280,7 +312,7 @@
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
-                iterate_batches=iterate_batches,
+                iterate_ppo_batches=iterate_ppo_batches,
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
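
For context, here is a minimal self-contained sketch of the two pieces of math this patch wires together, GAE-style advantages and the clipped PPO surrogate, on a single toy sequence using only mlx.core. The toy rewards, values, and log-probabilities below are illustrative assumptions, not values produced by the patch:

import mlx.core as mx

gamma, lam, cliprange = 0.99, 0.95, 0.2

# One toy sequence of three steps: per-step rewards and value estimates V(s_t).
rewards = mx.array([0.1, 0.0, 1.0])
values = mx.array([0.5, 0.4, 0.3])

# GAE: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), accumulated backwards in time.
advantages = mx.zeros_like(rewards)
running = 0.0
for t in reversed(range(rewards.shape[0])):
    next_value = values[t + 1] if t < rewards.shape[0] - 1 else 0.0
    delta = rewards[t] + gamma * next_value - values[t]
    running = delta + gamma * lam * running
    advantages[t] = running

# Clipped surrogate: take the worse of the unclipped and clipped objectives.
old_logprobs = mx.array([-1.2, -0.7, -2.0])
new_logprobs = mx.array([-1.0, -0.9, -1.5])
ratio = mx.exp(new_logprobs - old_logprobs)
pg_loss = mx.mean(mx.maximum(
    -advantages * ratio,
    -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange),
))
print(pg_loss.item())

Running this prints a single scalar policy loss; the value term in ppo_loss applies the same clipping idea to the value predictions via cliprange_value.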