diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 0c19031b..163b1be7 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -9,7 +9,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt, answer) tuple format required by GRPO trainer.
+    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
     """
     def __init__(
         self,
@@ -20,15 +20,14 @@ class GRPODataset:
     ):
         self._data = []
         for item in data:
-            # Get prompt and answer text
-            prompt = str(item[prompt_key])
-            answer = str(item[answer_key])
-
-            # Store as (prompt, answer) tuple
-            self._data.append((prompt, answer))
+            prompt_str = str(item[prompt_key])
+            answer_str = str(item[answer_key])
+            prompt_tokens = tokenizer.encode(prompt_str)
+            answer_tokens = tokenizer.encode(answer_str)
+            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
 
-    def __getitem__(self, idx: int) -> Tuple[str, str]:
-        """Returns a (prompt, answer) tuple for the given index."""
+    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]
 
     def __len__(self) -> int:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index e997b504..16125611 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -12,7 +12,7 @@ from mlx.utils import tree_flatten
 
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 
-from mlx_lm import generate
+from mlx_lm.utils import generate_step
 
 
 @dataclass
@@ -35,6 +35,70 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
+
+def generate_for_grpo(
+    model,
+    prompt,
+    max_tokens,
+    tokenizer,
+    temperature=1.0
+):
+    try:
+
+        # Ensure prompt is the right shape
+        if len(prompt.shape) == 1:
+            prompt = prompt[None, :]
+
+        # Initialize generation
+        generated = []
+        current_prompt = prompt[0]
+
+        for step in range(max_tokens):
+            try:
+                # Get model output with explicit shape checking
+                current_batch = current_prompt[None, :]
+
+                logits = model(current_batch)
+
+                # Ensure we have the last token logits
+                token_logits = logits[0, -1]
+
+                # Apply temperature and get probabilities
+                if temperature > 0:
+                    token_logits = token_logits / temperature
+                probs = mx.softmax(token_logits)
+
+                # Sample the next token
+                next_token = mx.random.categorical(probs[None, :])
+                next_token = next_token[0]
+
+                # Force evaluation to catch any issues
+                mx.eval(next_token)
+                token_value = next_token.item()
+
+                # Add to generated sequence
+                generated.append(next_token)
+                current_prompt = mx.concatenate([current_prompt, next_token[None]])
+
+                if token_value == tokenizer.eos_token_id:
+                    break
+
+            except Exception as e:
+                raise
+
+        if not generated:
+            return prompt[0]
+
+        try:
+            result = mx.concatenate([prompt[0], mx.stack(generated)])
+            mx.eval(result)
+            return result
+        except Exception as e:
+            raise
+
+    except Exception as e:
+        raise
+
 
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -45,42 +109,45 @@ def r1_extract_xml_answer(text: str) -> str:
         print("[extract_xml_answer] Failed to extract answer from: ", text)
         return ""
 
-def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Calculates reward based on accuracy of extracted answers. - Args: prompts: List of input prompts completions: List of completion strings answer: Expected answer or list of answers **kwargs: Additional arguments - Returns: list[float]: Reward values for each completion """ extracted_responses = [r1_extract_xml_answer(r) for r in completions] - q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content'] return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] -def r1_int_reward_func(completions, **kwargs) -> list[float]: +def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: """Rewards numerical responses. - Args: + prompts: List of input prompts completions: List of completion strings + answer: Expected answer or list of answers **kwargs: Additional arguments - Returns: list[float]: Reward values for each completion """ extracted_responses = [r1_extract_xml_answer(r) for r in completions] return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] -def r1_strict_format_reward_func(completions, **kwargs) -> list[float]: +def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + """Rewards completions with flexible XML format.""" + pattern = r".*?\s*.*?" + matches = [re.match(pattern, r) for r in completions] + return [0.5 if match else 0.0 for match in matches] + +def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: """Rewards completions with strict XML format. - Args: + prompts: List of input prompts completions: List of completion strings + answer: Expected answer or list of answers **kwargs: Additional arguments - Returns: list[float]: Reward values for each completion """ @@ -88,98 +155,128 @@ def r1_strict_format_reward_func(completions, **kwargs) -> list[float]: matches = [re.match(pattern, r) for r in completions] return [0.5 if match else 0.0 for match in matches] -def r1_soft_format_reward_func(completions, **kwargs) -> list[float]: - """Rewards completions with flexible XML format. - - Args: - completions: List of completion strings - **kwargs: Additional arguments - - Returns: - list[float]: Reward values for each completion - """ - pattern = r".*?\s*.*?" - matches = [re.match(pattern, r) for r in completions] - return [0.5 if match else 0.0 for match in matches] - -def r1_count_xml(text: str) -> float: +def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: """Calculates score based on XML formatting. 
-
     Args:
-        text: Input text string
-
+        prompts: List of input prompts (unused)
+        completions: List of completion strings to evaluate
+        answer: Expected answer or list of answers (unused)
+        **kwargs: Additional arguments
     Returns:
-        float: Score based on XML tag presence and formatting
+        list[float]: List of scores based on XML tag presence and formatting
     """
-    count = 0.0
-    if text.count("<reasoning>\n") == 1:
-        count += 0.125
-    if text.count("\n</reasoning>\n") == 1:
-        count += 0.125
-    if text.count("\n<answer>\n") == 1:
-        count += 0.125
+    scores = []
+    for text in completions:
+        count = 0.0
+        if text.count("<reasoning>\n") == 1:
+            count += 0.125
+        if text.count("\n</reasoning>\n") == 1:
+            count += 0.125
+        if text.count("\n<answer>\n") == 1:
+            count += 0.125
         count -= len(text.split("\n</answer>\n")[-1])*0.001
-    if text.count("\n</answer>") == 1:
-        count += 0.125
+        if text.count("\n</answer>") == 1:
+            count += 0.125
         count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
-    return count
+        scores.append(count)
+    return scores
 
 
 def grpo_loss(
-        model,
-        tokenizer,
-        prompts,
-        reward_funcs=None,
-        beta=0.1,
-        group_size=4,
-        epsilon=1e-4,
-        ref_model=None
-    ):
-    """
-    Calculates the GRPO loss with support for multiple reward functions.
+    model,
+    tokenizer,
+    batch,
+    reward_funcs=None,
+    beta=0.1,
+    group_size=4,
+    epsilon=1e-4,
+    ref_model=None,
+    max_tokens=128,
+    temperature=1.0
+):
+    """Modified GRPO loss function with better error handling"""
+    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    batch_size = len(prompt_tokens)
 
-    Args:
-        model: The model to optimize
-        tokenizer: Tokenizer for processing inputs
-        prompts: List of input prompts
-        reward_funcs: List of reward functions to use
-        beta: KL penalty coefficient
-        group_size: Number of completions per prompt
-        epsilon: Small constant for numerical stability
-        ref_model: Optional reference model for KL divergence
-
-    Returns:
-        tuple: (loss, total_sequence_length, metrics_dict)
-    """
-    batch_size = len(prompts)
-    # Generate multiple completions for each prompt
+    # Generate completions for each prompt
     all_completions = []
+    all_completion_texts = []
 
-    for prompt in prompts:
+    for prompt in prompt_tokens:
+        prompt_tensor = mx.array(prompt)
         prompt_completions = []
+        prompt_completion_texts = []
+
+        # Generate group_size completions for each prompt
         for _ in range(group_size):
-            completion = generate(model, tokenizer, prompt)
-            prompt_completions.append(completion)
+            try:
+                completion_ids = generate_for_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer=tokenizer,
+                    temperature=temperature
+                )
+
+                # Verify completion_ids is not None
+                if completion_ids is None:
+                    print("Warning: generate_for_grpo returned None")
+                    break
+
+                completion_text = tokenizer.decode(completion_ids.tolist())
+
+                prompt_completions.append(completion_ids)
+                prompt_completion_texts.append(completion_text)
+
+            except Exception as e:
+                print(f"Error in completion generation: {str(e)}")
+                # Fallback to using original prompt
+                prompt_completions.append(prompt_tensor)
+                prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
+
         all_completions.extend(prompt_completions)
+        all_completion_texts.extend(prompt_completion_texts)
+
+    # Verify we have the expected number of completions
+    assert len(all_completions) == batch_size * group_size
+    assert len(all_completion_texts) == batch_size * group_size
 
-    # Tokenize all prompts + completions
-    tokenized_inputs = tokenizer(
-        [p + c for p, c in zip(prompts * group_size, all_completions)],
-        return_tensors="np",
-        padding=True
-    )
+    # Expand answer_text and prompt_text to match completion groups
+    expanded_answers = []
+    expanded_prompts = []
+    for i in range(batch_size):
+        expanded_answers.extend([answer_text[i]] * group_size)
+        expanded_prompts.extend([prompt_text[i]] * group_size)
+
+    # Verify we have the expected number of completions
+    assert len(all_completions) == batch_size * group_size
+    assert len(all_completion_texts) == batch_size * group_size
 
-    inputs = mx.array(tokenized_inputs["input_ids"])
-    attention_mask = mx.array(tokenized_inputs["attention_mask"])
+    max_length = max(ids.shape[0] for ids in all_completions)
+    padded_completions = []
+    attention_masks = []
 
-    # Get lengths for proper masking
+    for completion_ids in all_completions:
+        padding_length = max_length - completion_ids.shape[0]
+        if padding_length > 0:
+            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
+            padded_ids = mx.concatenate([completion_ids, padding])
+            mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
+        else:
+            padded_ids = completion_ids
+            mask = mx.ones_like(completion_ids)
+        padded_completions.append(padded_ids)
+        attention_masks.append(mask)
+
+    inputs = mx.stack(padded_completions)
+    attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
 
     # Get logits from current model
     logits = model(inputs).astype(mx.float32)
 
     # Calculate log probabilities
-    log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
+    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
 
     # Prepare targets
     targets = inputs[:, 1:]
@@ -197,7 +294,7 @@ def grpo_loss(
     else:
         ref_logits = model(inputs).astype(mx.float32)
 
-    ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
+    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
     ref_token_log_probs = mx.take_along_axis(
         ref_log_probs,
         targets.reshape(*targets.shape, 1),
@@ -210,7 +307,11 @@ def grpo_loss(
     # Calculate combined rewards from all reward functions
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
-        func_rewards = mx.array(reward_func(all_completions))
+        func_rewards = mx.array(reward_func(
+            prompts=prompt_text,
+            completions=all_completion_texts,
+            answer=answer_text
+        ))
         rewards += func_rewards
 
     # Normalize rewards if using multiple reward functions
@@ -245,8 +346,12 @@ def grpo_loss(
     # Collect metrics for each reward function separately
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
-        func_rewards = mx.array(reward_func(all_completions))
-        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        func_rewards = mx.array(reward_func(
+            prompts=prompt_text,
+            completions=all_completion_texts,
+            answer=answer_text
+        ))
+        # func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
         reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
 
@@ -264,26 +369,30 @@ def grpo_loss(
 
 
 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
-    Creates batches from prompt-answer pairs for GRPO training.
+    Creates batches from dataset entries for GRPO training.
 
     Args:
-        dataset: List of (prompt, answer) pairs
+        dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
        tokenizer: Tokenizer for processing inputs
         batch_size: Size of each batch
         max_seq_length: Maximum sequence length
         train: Whether this is for training
 
     Yields:
-        List of prompts for the current batch
+        Tuple containing:
+        - prompts_tokens: List of token sequences for current batch
+        - answers_tokens: List of token sequences
+        - prompts_text: List of prompt strings
+        - answers_text: List of answer strings
     """
-    # Verify dataset is not empty and has correct format
-    if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
-        raise ValueError("Dataset must be a list of (prompt, answer) pairs")
-
-    # Sort by combined length of prompt + answer
+    # Verify dataset format
+    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+    # Sort by combined length of prompt + answer tokens
     idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
-
+
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
@@ -306,22 +415,22 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
         indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
 
         for i in indices:
-            # Get current batch of prompt-answer pairs
+            # Get current batch
             current_batch = [dataset[j] for j in batch_idx[i]]
 
-            # Extract prompts and answers
-            prompts = [pair[0] for pair in current_batch]
-            answers = [pair[1] for pair in current_batch]
-
-            if any(len(p) > max_seq_length for p in prompts):
+            # Extract all components
+            prompts_tokens = [item[0] for item in current_batch]
+            answers_tokens = [item[1] for item in current_batch]
+            prompts_text = [item[2] for item in current_batch]
+            answers_text = [item[3] for item in current_batch]
+
+            if any(len(p) > max_seq_length for p in prompts_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )
-
-            # For GRPO, we only need to yield the prompts
-            # The answers will be used by the reward functions
-            yield prompts
+
+            yield prompts_tokens, answers_tokens, prompts_text, answers_text
 
         if not train:
             break
@@ -342,11 +451,19 @@ def evaluate_grpo(
     loss: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
+    """
+    Evaluate model using GRPO loss.
+    Returns:
+        tuple: (average loss, number of tokens, average metrics)
+    """
     all_losses = 0
     ntokens = 0
-
+    all_metrics = None  # Initialize metrics dictionary
+
+    # Create iterator for batches
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-
+
+    # Iterate through batches
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -356,35 +473,41 @@
             max_seq_length=max_seq_length,
         ),
     ):
-        prompts = batch
+        # Calculate loss for current batch
         losses, toks, metrics = loss(
             model=model,
             tokenizer=tokenizer,
-            prompts=prompts,
+            batch=batch,
             reward_funcs=reward_funcs,
             beta=beta,
             group_size=group_size,
             epsilon=epsilon,
             ref_model=ref_model
         )
+
+        # Accumulate losses and tokens
         all_losses += losses * toks
         ntokens += toks
-
+
+        # Accumulate metrics
         if all_metrics is None:
             all_metrics = {k: v * toks for k, v in metrics.items()}
         else:
             for k, v in metrics.items():
                 all_metrics[k] += v * toks
-
+
+    # Evaluate accumulated values
     mx.eval(all_losses, ntokens)
-
+
+    # Aggregate across distributed workers
     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
     all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
-
+
+    # Calculate averages
     avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
     avg_loss = (all_losses / ntokens).item()
-
+
     return avg_loss, ntokens, avg_metrics
 
 
@@ -420,8 +543,18 @@ def train_grpo(
     state = [model.state, optimizer.state]
 
     def step(batch):
+        # Forward and backward pass
-        (loss, toks, metrics), grad = loss_value_and_grad(model, *batch)
+        (loss, toks, metrics), grad = loss_value_and_grad(
+            model,
+            tokenizer=tokenizer,
+            batch=batch,
+            reward_funcs=reward_funcs,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            ref_model=ref_model
+        )
 
         # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)
@@ -430,7 +563,7 @@ def train_grpo(
         optimizer.update(model, grad)
 
         return loss, toks, metrics
-
+
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
     losses = 0