diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 4bb39832..4d79b1ac 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -1,21 +1,21 @@
 # Copyright © 2024 Apple Inc.
 
+from pathlib import Path
 import argparse
+import types
 import math
 import os
 import re
-import types
-from pathlib import Path
 
-import mlx.nn as nn
 import mlx.optimizers as optim
+import mlx.nn as nn
 import numpy as np
 import yaml
 
+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
-from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
     "max_completion_length": 512,
     "use_chat_template": False,
     "use_prompt": False,
+    "temperature": 1.0,
+    "reward_weights": None,
 }
 
 
@@ -224,6 +226,18 @@ def build_parser():
         help="Rather to use the prompt from the R1 paper.",
         default=None,
     )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        help="Temperature for sampling. The higher the temperature, the more random the completions.",
+        default=1.0,
+    )
+    parser.add_argument(
+        "--reward-weights",
+        type=str,
+        help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
+        default=None,
+    )
     return parser
 
 def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
         beta=args.beta,
         group_size=args.group_size,
         epsilon=args.epsilon,
-        reference_model_path=args.reference_model_path
+        reference_model_path=args.reference_model_path,
+        temperature=args.temperature,
+        reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
     )
 
     if args.reference_model_path:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index d0fa5fae..e96b8f29 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
+    temperature: float = field(
+        default=1.0,
+        metadata={
+            "help": "Temperature for sampling. The higher the temperature, the more random the completions."
+        }
+    )
+    reward_weights: Optional[List[float]] = field(
+        default=None,
+        metadata={
+            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
+        }
+    )
 
 
 RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
 
 
 def r1_extract_xml_answer(text: str) -> str:
-    """Extracts the answer from an XML formatted text string."""
     try:
         answer = text.split("<answer>")[-1]
         answer = answer.split("</answer>")[0]
@@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:
 
 
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 
 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
@@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
 
 
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
@@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
 
 
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
@@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,
 
 
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Ensures we always return a list of floats."""
     if not completions:
         return [0.0] * len(prompts)
-
     scores = []
     for text in completions:
         if not text:
             scores.append(0.0)
             continue
-
         count = 0.0
         if text.count("<think>\n") == 1:
             count += 0.125
@@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
             count += 0.125
         if text.count("\n</answer>\n") == 1:
             count += 0.125
-
-        # Penalize extra text after
         end_text = text.split("\n</answer>\n")[-1]
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
-
-        scores.append(max(0.0, count))  # Ensure non-negative score
-
+        scores.append(max(0.0, count))
     return scores
 
 
@@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
         return None
-
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-
     initial_length = prompt.shape[1]
     output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
     output[:initial_length] = prompt[0]
     current_length = initial_length
-
     try:
         def sample(logits):
             if temperature > 0:
                 logits /= temperature
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
-
         for _ in range(max_tokens):
             current_input = output[:current_length][None, :]
             logits = model(current_input)
@@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
             token_value = next_token.item()
             output[current_length] = token_value
             current_length += 1
-
             if token_value == tokenizer.eos_token_id:
                 break
-
             if current_length >= end_sequence_length:
                 last_tokens = output[current_length - end_sequence_length:current_length].tolist()
                 if last_tokens == end_sequence:
                     break
-
         if current_length > initial_length:
             return output[:current_length]
-
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
@@ -192,9 +184,10 @@ def grpo_loss(
     group_size=4,
     epsilon=1e-4,
     max_tokens=64,
-    temperature=1.0
+    temperature=1.0,
+    reward_weights=None
 ):
-    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    prompt_tokens, _, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)
 
     all_completions = []
@@ -273,18 +266,34 @@ def grpo_loss(
     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
-    # Rewards and advantages
-    rewards = mx.zeros((len(all_completions),))
+    # Create array to store rewards from each function
+    all_func_rewards = []
+
+    # Collect rewards from each function separately
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
             prompts=expanded_prompts,
             completions=all_completion_texts,
             answer=expanded_answers
         ))
-        rewards += func_rewards
+        all_func_rewards.append(func_rewards)
 
-    if len(reward_funcs) > 1:
-        rewards /= len(reward_funcs)
+    # Stack rewards to shape (num_samples, num_funcs)
+    rewards = mx.stack(all_func_rewards, axis=1)
+    print(f"Rewards: {rewards}")
+
+    # Apply weights and sum
+    if reward_weights is not None:
+        if len(reward_weights) != len(reward_funcs):
+            raise ValueError(
+                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
+                f"functions ({len(reward_funcs)})"
+            )
+        reward_weights = mx.array(reward_weights, dtype=mx.float32)
+    else:
+        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
+    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
+    print(f"Rewards after weights: {rewards}")
 
     # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
@@ -397,15 +406,11 @@ def evaluate_grpo(
     epsilon: float,
     group_size: int,
     max_seq_length,
+    temperature: float,
     reward_funcs: Optional[List[RewardFunctions]] = None,
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
-    """
-    Evaluate model using GRPO loss.
-    Returns:
-        tuple: (average loss, number of tokens, average metrics)
-    """
     all_losses = 0
     ntokens = 0
     all_metrics = None
@@ -428,7 +433,8 @@ def evaluate_grpo(
             beta=beta,
             group_size=group_size,
             epsilon=epsilon,
-            ref_model=ref_model
+            ref_model=ref_model,
+            temperature=temperature
         )
 
         all_losses += losses * toks
@@ -442,12 +448,10 @@ def evaluate_grpo(
 
     mx.eval(all_losses, ntokens)
 
-    # Aggregate across distributed workers
     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
     all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
 
-    # Calculate averages
     avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
     avg_loss = (all_losses / ntokens).item()
 
@@ -486,8 +490,6 @@ def train_grpo(
     state = [model.state, optimizer.state]
 
     def step(batch):
-
-        # Forward and backward pass
         (loss, toks, metrics), grad = loss_value_and_grad(
             model,
             tokenizer=tokenizer,
@@ -498,12 +500,11 @@ def train_grpo(
             epsilon=args.epsilon,
             ref_model=ref_model,
             max_tokens=args.max_completion_length,
+            temperature=args.temperature
         )
 
-        # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)
 
-        # Model update
         optimizer.update(model, grad)
 
         return loss, toks, metrics
@@ -536,8 +537,6 @@ def train_grpo(
             train=True,
         ),
     ):
-        # Report validation loss if needed, the first validation loss
-        # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss, val_ntokens, val_metrics = evaluate_grpo(
@@ -553,6 +552,7 @@ def train_grpo(
                 max_seq_length=args.max_seq_length,
                 beta=args.beta,
                 epsilon=args.epsilon,
+                temperature=args.temperature,
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
@@ -566,7 +566,6 @@ def train_grpo(
                 f"Val kl {val_metrics['kl']:.3f}"
             )
 
-            # Add reward function specific metrics
             for i, reward_func in enumerate(reward_funcs):
                 val_metrics_str += (
                     f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
@@ -622,7 +621,6 @@ def train_grpo(
                 f"KL {avg_metrics['kl']:.3f}"
            )
 
-            # Add reward function specific metrics
             for i, reward_func in enumerate(reward_funcs):
                 func_name = reward_func.__name__
                 train_metrics_str += (
@@ -656,7 +654,6 @@ def train_grpo(
             steps = 0
             start = time.perf_counter()
 
-        # Save adapter weights
         if it % args.steps_per_save == 0:
             adapter_weights = dict(tree_flatten(model.trainable_parameters()))
             mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -669,7 +666,6 @@ def train_grpo(
                 f"{args.adapter_file} and {checkpoint}."
             )
 
-    # Save final weights
     adapter_weights = dict(tree_flatten(model.trainable_parameters()))
     mx.save_safetensors(str(args.adapter_file), adapter_weights)
     print(f"Saved final weights to {args.adapter_file}.")
\ No newline at end of file
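
Note on the weighted-reward change in grpo_loss above: per-function rewards are now stacked into a (num_samples, num_funcs) array and reduced with one weight per reward function, instead of being summed and averaged. Below is a minimal standalone sketch of that reduction, using made-up reward values and the same bracketed string format accepted by the new --reward-weights flag; parse_reward_weights is a hypothetical helper for illustration, not part of the patch.

    import mlx.core as mx

    def parse_reward_weights(arg: str) -> list[float]:
        # "[2.0, 1.0]" -> [2.0, 1.0], mirroring the parsing done in lora.py
        return [float(x) for x in arg.strip("[]").split(",")]

    # Example: two reward functions scored over three sampled completions.
    accuracy_rewards = mx.array([1.0, 0.0, 1.0])
    format_rewards = mx.array([0.5, 0.5, 0.0])

    # Stack to shape (num_samples, num_funcs), then apply per-function weights.
    rewards = mx.stack([accuracy_rewards, format_rewards], axis=1)
    weights = mx.array(parse_reward_weights("[2.0, 1.0]"))
    total = (rewards * mx.expand_dims(weights, 0)).sum(axis=1)
    print(total)  # -> [2.5, 0.5, 2.0]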