diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index b948ae01..68fa93da 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -206,15 +206,15 @@ def build_parser():
     )
     parser.add_argument(
         "--use-chat-template",
-        type=bool,
+        action="store_true",
         help="If the model is a Chat model, use the Chat template.",
-        default=False,
+        default=None,
     )
     parser.add_argument(
         "--use-prompt",
-        type=bool,
-        help="Rather to use the prompt from teh R1 paper.",
-        default=False,
+        action="store_true",
+        help="Whether to use the prompt from the R1 paper.",
+        default=None,
     )
     return parser
 
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 0210b44a..64b0bc49 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -12,7 +12,6 @@ from mlx.utils import tree_flatten
 
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 
-
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
     group_size: int = field(
@@ -35,7 +34,6 @@ class GRPOTrainingArgs(TrainingArgs):
         }
     )
 
-
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -46,62 +44,30 @@ def r1_extract_xml_answer(text: str) -> str:
         print("[extract_xml_answer] Failed to extract answer from: ", text)
         return ""
 
-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Calculates reward based on accuracy of extracted answers.
-    Args:
-        prompts: List of input prompts
-        completions: List of completion strings
-        answer: Expected answer or list of answers
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: Reward values for each completion
-    """
-    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards numerical responses.
-    Args:
-        prompts: List of input prompts
-        completions: List of completion strings
-        answer: Expected answer or list of answers
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: Reward values for each completion
-    """
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
 
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Rewards completions with flexible XML format."""
     pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
     matches = [re.match(pattern, r) for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards completions with strict XML format.
-    Args:
-        prompts: List of input prompts
-        completions: List of completion strings
-        answer: Expected answer or list of answers
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: Reward values for each completion
-    """
     pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
     matches = [re.match(pattern, r) for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Calculates score based on XML formatting.
-    Args:
-        prompts: List of input prompts (unused)
-        completions: List of completion strings to evaluate
-        answer: Expected answer or list of answers (unused)
-        **kwargs: Additional arguments
-    Returns:
-        list[float]: List of scores based on XML tag presence and formatting
-    """
     scores = []
     for text in completions:
         count = 0.0
@@ -116,10 +82,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
             count += 0.125
             count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
         scores.append(count)
-
     return scores
 
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
@@ -172,30 +137,24 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
 
 
 def get_per_token_logps(model, inputs, lengths):
-    logits = model(inputs).astype(mx.float16)  # [batch_size, seq_len, vocab_size]
-    # Remove last position as it corresponds to the next token prediction
-    logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
-    targets = inputs[:, 1:]  # Shift inputs to get targets
+    logits = model(inputs).astype(mx.float16)
+    logits = logits[:, :-1, :]
+    targets = inputs[:, 1:]
 
-    # Process sequences individually to save memory
     per_token_logps = []
     for i in range(logits.shape[0]):
-        # Get sequence length for this example
-        seq_len = int(lengths[i]) - 1  # -1 because we removed last position
+        seq_len = int(lengths[i]) - 1
 
-        # Get logits and targets for this sequence
-        seq_logits = logits[i, :seq_len]  # [seq_len, vocab_size]
-        seq_targets = targets[i, :seq_len]  # [seq_len]
+        seq_logits = logits[i, :seq_len]
+        seq_targets = targets[i, :seq_len]
 
-        # Compute log probabilities
-        log_probs = nn.log_softmax(seq_logits, axis=-1)  # [seq_len, vocab_size]
+        log_probs = nn.log_softmax(seq_logits, axis=-1)
 
-        # Gather log probs for actual tokens
         token_log_probs = mx.take_along_axis(
             log_probs,
             seq_targets.reshape(seq_len, 1),
             axis=-1
-        ).squeeze(-1)  # [seq_len]
+        ).squeeze(-1)
 
         per_token_logps.append(token_log_probs)
     mx.eval(logits)
@@ -316,7 +275,7 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
 
     # Compute KL divergence using Schulman's approximator
-    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
+    kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
 
     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
@@ -325,10 +284,10 @@ def grpo_loss(
     policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
 
     # Compute per-token loss following GRPO formula
-    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
+    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
 
     # Average over tokens and sequences
-    sequence_sums = (per_token_loss * length_mask).sum(axis=1)
+    sequence_sums = per_token_loss.sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()
 
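A quick standalone sanity check of the new kl_div line (a sketch, not part of the patch; the log-prob values below are made up): the updated expression computes exp(d) - d - 1 with d = token_log_probs - ref_token_log_probs, the non-negative "k3"-style form of the Schulman approximator named in the comment, while the previous line used the same form with d negated.

    import mlx.core as mx

    token_log_probs = mx.array([[-1.2, -0.7], [-0.3, -2.0]])      # hypothetical policy log-probs
    ref_token_log_probs = mx.array([[-1.0, -0.9], [-0.5, -1.5]])  # hypothetical reference log-probs

    d = token_log_probs - ref_token_log_probs
    kl_div = (mx.exp(d) - 1) - d         # same expression as the updated grpo_loss line
    print(mx.all(kl_div >= 0).item())    # True: exp(d) - d - 1 >= 0 for every d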