From e33d9d509bee308daf74f1822708f14954565934 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 12 Feb 2025 11:07:53 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/datasets.py     |  34 +++-----
 llms/mlx_lm/tuner/grpo_trainer.py | 126 +++++++++++++-----------
 llms/mlx_lm/tuner/utils.py        |  21 +----
 3 files changed, 70 insertions(+), 111 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index d82fa0ff..fb19ba50 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -2,9 +2,8 @@ import itertools
 import json
 import types
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
-from .utils import GRPOExample
 from transformers import PreTrainedTokenizer
 
 
@@ -12,7 +11,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data as GRPOExample instances.
+    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
     """
     def __init__(
         self,
@@ -23,40 +22,33 @@ class GRPODataset:
         use_chat_template: bool = False,
         use_prompt: bool = False
     ):
-        self._data: List[GRPOExample] = []
+        self._data = []
        for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
-
             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
                         {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
-                {'role': 'user', 'content': prompt_str}
+                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+                        {'role': 'user', 'content': prompt_str}
                     ],
                 )
                 answer_tokens = tokenizer.encode(answer_str)
             else:
                 if use_prompt:
                     prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
-                User: {prompt_str} Assistant: """)
+                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
+                        User: {prompt_str} Assistant: """)
                 else:
                     prompt_tokens = tokenizer.encode(prompt_str)
                 answer_tokens = tokenizer.encode(answer_str)
-
-            self._data.append(GRPOExample(
-                prompt_tokens=prompt_tokens,
-                answer_tokens=answer_tokens,
-                prompt_text=prompt_str,
-                answer_text=answer_str
-            ))
+            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
 
-    def __getitem__(self, idx: int) -> GRPOExample:
-        """Returns a GRPOExample instance."""
+    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]
 
     def __len__(self) -> int:
@@ -318,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
     else:
         print(f"Loading Hugging Face dataset {args.data}.")
-        train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
+        train, valid, test = load_hf_dataset(args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 13954665..d0fa5fae 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -1,18 +1,18 @@
 # Copyright © 2024 Apple Inc.
 
-import time
+from typing import List, Optional, Callable
 from dataclasses import dataclass, field
-from typing import List, Iterator, Optional
 from pathlib import Path
+import time
 import re
 
+from mlx.utils import tree_flatten
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from mlx.utils import tree_flatten
 
-from .utils import GRPOBatch, GRPOExample
-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
+
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -37,6 +37,9 @@
     )
 
 
+RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
+
+
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -180,10 +183,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
 
 
 def grpo_loss(
-    model: nn.Module,
-    ref_model: Optional[nn.Module],
+    model,
+    ref_model,
     tokenizer,
-    batch=GRPOBatch,
+    batch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
@@ -191,18 +194,14 @@ def grpo_loss(
     max_tokens=64,
     temperature=1.0
 ):
-    prompts_tokens = batch.prompt_tokens
-    answers_tokens = batch.answer_tokens
-    prompts_text = batch.prompt_texts
-    answers_text = batch.answer_texts
-    batch_size = len(prompts_tokens)
-
-    # Generation logic remains the same
+    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    batch_size = len(prompt_tokens)
+
     all_completions = []
     all_completion_texts = []
 
     for i in range(0, batch_size, batch_size):
-        batch_prompts = prompts_tokens[i:i+batch_size]
+        batch_prompts = prompt_tokens[i:i+batch_size]
         for prompt in batch_prompts:
             prompt_tensor = mx.array(prompt)
             for _ in range(group_size):
@@ -212,8 +211,6 @@ def grpo_loss(
                         completion_text = tokenizer.decode(completion_ids.tolist())
                         all_completions.append(completion_ids)
                         all_completion_texts.append(completion_text)
-
-                        # Clear completion tensors
                         mx.eval(completion_ids)
                         del completion_ids
                 except Exception as e:
@@ -222,12 +219,11 @@ def grpo_loss(
 
     mx.metal.clear_cache()
 
-    # Prepare inputs
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
-        expanded_answers.extend([answers_text[i]] * group_size)
-        expanded_prompts.extend([prompts_text[i]] * group_size)
+        expanded_answers.extend([answer_text[i]] * group_size)
+        expanded_prompts.extend([prompt_text[i]] * group_size)
 
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
@@ -260,6 +256,8 @@ def grpo_loss(
         ref_token_log_probs = token_log_probs
     else:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+        mx.eval(ref_token_log_probs)
+        mx.metal.clear_cache()
 
     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
@@ -275,7 +273,7 @@ def grpo_loss(
     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
-    # Calculate rewards and advantages
+    # Rewards and advantages
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
@@ -288,7 +286,7 @@ def grpo_loss(
     if len(reward_funcs) > 1:
         rewards /= len(reward_funcs)
 
-    # Reshape rewards and compute advantages following GRPO formula
+    # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
     mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
     std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
@@ -303,7 +301,7 @@ def grpo_loss(
     # Compute policy ratio
     policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
 
-    # Compute per-token loss following GRPO formula
+    # Compute per-token loss
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
 
     # Average over tokens
@@ -339,59 +337,51 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics
 
 
-def iterate_grpo_batches(
-    dataset: List[GRPOExample],
-    batch_size: int,
-    max_seq_length: int,
-    train: bool = False,
-) -> Iterator[GRPOBatch]:
+def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
+    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+    def length_key(i):
+        return len(dataset[i][0]) + len(dataset[i][1])
+
+    idx = sorted(range(len(dataset)), key=length_key)
+
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
             f"examples but only has {len(dataset)}."
         )
 
-    # Get MLX distributed setup
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")
 
-    # Sort by combined length for efficient batching
-    def length_key(example: GRPOExample) -> int:
-        return len(example.prompt_tokens) + len(example.answer_tokens)
-
-    sorted_dataset = sorted(dataset, key=length_key)
-
-    # Create batch indices
-    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
-    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
-
+    def batch_index_generator():
+        for i in range(0, len(idx) - batch_size + 1, batch_size):
+            yield idx[i : i + batch_size : step]
+
     while True:
-        # Shuffle batch start indices
-        shuffled_starts = np.random.permutation(batch_starts)
+        indices = (
+            np.random.permutation(list(batch_index_generator())) if train
+            else batch_index_generator()
+        )
 
-        for start_idx in shuffled_starts:
-            # Account for distributed setup by taking every step-th example
-            batch_idx = list(range(start_idx, start_idx + batch_size, step))
-            current_batch = [sorted_dataset[j] for j in batch_idx]
+        for batch_idx in indices:
+            current_batch = [dataset[j] for j in batch_idx]
 
-            # Create batch using dataclass attributes
-            batch = GRPOBatch(
-                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
-                answer_tokens=[ex.answer_tokens for ex in current_batch],
-                prompt_texts=[ex.prompt_text for ex in current_batch],
-                answer_texts=[ex.answer_text for ex in current_batch]
-            )
-
-            # Check sequence lengths
-            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
+            prompts_tokens = [item[0] for item in current_batch]
+            answers_tokens = [item[1] for item in current_batch]
+            prompts_text = [item[2] for item in current_batch]
+            answers_text = [item[3] for item in current_batch]
+
+            if any(len(p) > max_seq_length for p in prompts_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )
-
-            yield batch
-
+
+            yield prompts_tokens, answers_tokens, prompts_text, answers_text
+
         if not train:
             break
 
@@ -407,7 +397,7 @@ def evaluate_grpo(
     epsilon: float,
     group_size: int,
     max_seq_length,
-    reward_funcs = None,
+    reward_funcs: Optional[List[RewardFunctions]] = None,
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
@@ -418,12 +408,10 @@ def evaluate_grpo(
     """
     all_losses = 0
     ntokens = 0
-    all_metrics = None  # Initialize metrics dictionary
+    all_metrics = None
 
-    # Create iterator for batches
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
 
-    # Iterate through batches
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -432,7 +420,6 @@ def evaluate_grpo(
             max_seq_length=max_seq_length,
         ),
     ):
-        # Calculate loss for current batch
         losses, toks, metrics = loss_fn(
             model=model,
             tokenizer=tokenizer,
@@ -444,18 +431,15 @@ def evaluate_grpo(
             ref_model=ref_model
         )
 
-        # Accumulate losses and tokens
         all_losses += losses * toks
         ntokens += toks
 
-        # Accumulate metrics
         if all_metrics is None:
             all_metrics = {k: v * toks for k, v in metrics.items()}
         else:
             for k, v in metrics.items():
                 all_metrics[k] += v * toks
 
-    # Evaluate accumulated values
     mx.eval(all_losses, ntokens)
 
     # Aggregate across distributed workers
@@ -475,9 +459,9 @@ def train_grpo(
     ref_model: Optional[nn.Module],
     tokenizer,
     optimizer,
-    train_dataset: List[GRPOExample],
-    val_dataset: List[GRPOExample],
-    reward_funcs = [
+    train_dataset,
+    val_dataset,
+    reward_funcs: Optional[List[RewardFunctions]] = [
         r1_accuracy_reward_func,
         r1_int_reward_func,
         r1_strict_format_reward_func,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index d3497177..7586fda4 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -2,8 +2,7 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict, List
-from dataclasses import dataclass
+from typing import Dict
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -275,20 +274,4 @@ def print_trainable_parameters(model):
     print(
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
         f"({trainable_p:.3f}M/{total_p:.3f}M)"
-    )
-
-@dataclass
-class GRPOExample:
-    """Single example for GRPO training/inference."""
-    prompt_tokens: List[int]
-    answer_tokens: List[int]
-    prompt_text: str
-    answer_text: str
-
-@dataclass
-class GRPOBatch:
-    """A batch of GRPO examples."""
-    prompt_tokens: List[List[int]]
-    answer_tokens: List[List[int]]
-    prompt_texts: List[str]
-    answer_texts: List[str]
\ No newline at end of file
+    )
\ No newline at end of file
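
Note for reviewers (not part of the patch): the change replaces the GRPOExample/GRPOBatch dataclasses with plain tuples, so GRPODataset.__getitem__ returns (prompt_tokens, answer_tokens, prompt_str, answer_str), iterate_grpo_batches yields four parallel lists, and grpo_loss unpacks the batch positionally. The snippet below only illustrates that data flow; the token ids and texts are invented.

    from typing import List, Tuple

    # (prompt_tokens, answer_tokens, prompt_str, answer_str), as stored by GRPODataset
    Example = Tuple[List[int], List[int], str, str]

    dataset: List[Example] = [
        ([101, 2054, 2003], [1018], "What is 2+2?", "4"),
        ([101, 2171, 1037], [2698], "Name a prime.", "7"),
    ]

    # Shape of one batch yielded by iterate_grpo_batches: four parallel lists.
    prompts_tokens = [item[0] for item in dataset]
    answers_tokens = [item[1] for item in dataset]
    prompts_text = [item[2] for item in dataset]
    answers_text = [item[3] for item in dataset]

    # grpo_loss then unpacks the yielded batch positionally:
    prompt_tokens, answer_tokens, prompt_text, answer_text = (
        prompts_tokens,
        answers_tokens,
        prompts_text,
        answers_text,
    )
    assert len(prompt_tokens) == len(answer_tokens) == len(prompt_text) == len(answer_text)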
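
A second reviewer sketch, also not part of the patch: the group-relative advantage normalization that grpo_loss derives from the reshaped rewards (hunk at old line 288). The patch shows only the per-group mean/std tensors; the final (rewards - mean) / (std + epsilon) step and the reward values below are assumptions for illustration.

    import mlx.core as mx

    batch_size, group_size, epsilon = 2, 4, 1e-4

    # One scalar reward per sampled completion (batch_size * group_size values, made up).
    rewards = mx.array([0.0, 1.0, 0.5, 1.0, 0.2, 0.0, 0.9, 0.3])

    # Per-prompt (group) statistics, broadcast back to one value per completion,
    # mirroring the mean_rewards / std_rewards tensors built in grpo_loss.
    rewards_reshaped = rewards.reshape(batch_size, group_size)
    mean_rewards = mx.broadcast_to(
        mx.mean(rewards_reshaped, axis=1)[:, None], (batch_size, group_size)
    ).reshape(-1)
    std_rewards = mx.broadcast_to(
        mx.std(rewards_reshaped, axis=1)[:, None], (batch_size, group_size)
    ).reshape(-1)

    # Assumed normalization: completions above their group mean get a positive advantage.
    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
    print(advantages)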