From ec50a869b0c5cd9c502e4dfe0318ede67bb6efa4 Mon Sep 17 00:00:00 2001 From: cshang Date: Wed, 5 Feb 2025 18:50:14 -0800 Subject: [PATCH] init grpo --- llms/mlx_lm/convert2.py | 19 + llms/mlx_lm/lora.py | 163 ++++++- llms/mlx_lm/test_grpo | 1 + llms/mlx_lm/tuner/datasets.py | 92 +++- llms/mlx_lm/tuner/grpo_trainer.py | 690 ++++++++++++++++++++++++++++++ 5 files changed, 932 insertions(+), 33 deletions(-) create mode 100644 llms/mlx_lm/convert2.py create mode 160000 llms/mlx_lm/test_grpo create mode 100644 llms/mlx_lm/tuner/grpo_trainer.py diff --git a/llms/mlx_lm/convert2.py b/llms/mlx_lm/convert2.py new file mode 100644 index 00000000..d4bf996d --- /dev/null +++ b/llms/mlx_lm/convert2.py @@ -0,0 +1,19 @@ +import pandas as pd +import os + +# Define dataset directory +dataset_dir = "/Users/cshang/Desktop/test_grpo/data" + +# Convert each Parquet file to JSONL +for file in os.listdir(dataset_dir): + if file.endswith(".parquet"): + parquet_path = os.path.join(dataset_dir, file) + jsonl_path = os.path.join(dataset_dir, file.replace(".parquet", ".jsonl")) + + # Load Parquet file + df = pd.read_parquet(parquet_path) + + # Convert to JSONL format + df.to_json(jsonl_path, orient="records", lines=True) + + print(f"Converted {parquet_path} -> {jsonl_path}") \ No newline at end of file diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 43f508c3..2672f3b7 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -15,6 +15,7 @@ import yaml from .tokenizer_utils import TokenizerWrapper from .tuner.datasets import load_dataset from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train +from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo from .tuner.utils import ( build_schedule, linear_to_lora_layers, @@ -42,6 +43,7 @@ yaml_loader.add_implicit_resolver( CONFIG_DEFAULTS = { "model": "mlx_model", "train": False, + "training_mode": "normal", "fine_tune_type": "lora", "data": "data/", "seed": 0, @@ -62,6 +64,15 @@ CONFIG_DEFAULTS = { "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, + + # GRPO args + "reference_model_path": None, + "group_size": 4, + "beta": 0.1, + "epsilon": 1e-4, + "max_completion_length": 512, + "use_chat_template": False, + "use_prompt": False, } @@ -94,6 +105,12 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) + parser.add_argument( + "--training-mode", + type=str, + choices=["normal", "grpo"], + help="Training mode: normal or GRPO", + ) parser.add_argument( "--num-layers", type=int, @@ -161,6 +178,44 @@ def build_parser(): default=None, ) parser.add_argument("--seed", type=int, help="The PRNG seed") + + # GRPO args + parser.add_argument( + "--group-size", + type=int, + help="Number of generations.", + default=4, + ) + parser.add_argument( + "--max-completion-length", + type=int, + help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.", + default=512, + ) + parser.add_argument( + "--beta", + type=float, + help="KL penalty coefficient.", + default=0.1, + ) + parser.add_argument( + "--epsilon", + type=float, + help="The Epsilon for numerical stability.", + default=1e-4, + ) + parser.add_argument( + "--use-chat-template", + action="store_true", + help="If the model is a Chat model, use the Chat template.", + default=None, + ) + parser.add_argument( + "--use-prompt", + action="store_true", + help="Rather to use the prompt from the R1 paper.", + default=None, + ) return parser @@ -220,32 +275,102 @@ def train_model( ) ) # Train model - train( - model=model, - tokenizer=tokenizer, - args=training_args, - optimizer=opt, - train_dataset=train_set, - val_dataset=valid_set, - training_callback=training_callback, - ) + if args.training_mode == "grpo": + training_args = GRPOTrainingArgs( + batch_size=args.batch_size, + iters=args.iters, + val_batches=args.val_batches, + steps_per_report=args.steps_per_report, + steps_per_eval=args.steps_per_eval, + steps_per_save=args.save_every, + adapter_file=adapter_file, + max_seq_length=args.max_seq_length, + max_completion_length=args.max_completion_length, + grad_checkpoint=args.grad_checkpoint, + beta=args.beta, + group_size=args.group_size, + epsilon=args.epsilon, + reference_model_path=args.reference_model_path + ) + + if args.reference_model_path: + reference_model, _ = load(args.reference_model_path) + reference_model = reference_model.freeze() + else: + reference_model, _ = load(args.model) + + train_grpo( + model=model, + ref_model=reference_model, + tokenizer=tokenizer, + optimizer=opt, + train_dataset=train_set, + val_dataset=valid_set, + args=training_args, + training_callback=training_callback, + ) + else: + training_args = TrainingArgs( + batch_size=args.batch_size, + iters=args.iters, + val_batches=args.val_batches, + steps_per_report=args.steps_per_report, + steps_per_eval=args.steps_per_eval, + steps_per_save=args.save_every, + adapter_file=adapter_file, + max_seq_length=args.max_seq_length, + grad_checkpoint=args.grad_checkpoint + ) + + train( + model=model, + tokenizer=tokenizer, + args=training_args, + optimizer=opt, + train_dataset=train_set, + val_dataset=valid_set, + training_callback=training_callback, + ) def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set): model.eval() - test_loss = evaluate( - model=model, - dataset=test_set, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_batches=args.test_batches, - max_seq_length=args.max_seq_length, - ) + if args.training_mode == "grpo": + if args.reference_model_path: + reference_model, _ = load(args.reference_model_path) + else: + reference_model = model - test_ppl = math.exp(test_loss) + test_loss, _, test_rewards = evaluate_grpo( + model=model, + ref_model=reference_model, + dataset=test_set, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.test_batches, + max_seq_length=args.max_seq_length, + beta=args.beta, + group_size=args.group_size, + epsilon=args.epsilon + ) - print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") + test_ppl = math.exp(test_loss) + + print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}") + else: + test_loss = evaluate( + model=model, + dataset=test_set, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.test_batches, + max_seq_length=args.max_seq_length, + ) + + test_ppl = math.exp(test_loss) + + print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") def run(args, training_callback: TrainingCallback = None): diff --git a/llms/mlx_lm/test_grpo b/llms/mlx_lm/test_grpo new file mode 160000 index 00000000..a74695c9 --- /dev/null +++ b/llms/mlx_lm/test_grpo @@ -0,0 +1 @@ +Subproject commit a74695c9280dd46208ea000f507f44bc8ddd9533 diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 377e7cae..5d0ff68e 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,10 +1,59 @@ import json from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from transformers import PreTrainedTokenizer +class GRPODataset: + """ + Dataset wrapper for GRPO training data. + Each example should have a 'prompt' and 'answer' field. + Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format. + """ + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + prompt_key: str = "prompt", + answer_key: str = "answer", + use_chat_template: bool = False, + use_prompt: bool = False + ): + self._data = [] + for item in data: + prompt_str = str(item[prompt_key]) + answer_str = str(item[answer_key]) + if use_chat_template: + prompt_tokens = tokenizer.apply_chat_template( + [ + {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. + The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer. + The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""}, + {'role': 'user', 'content': prompt_str} + ], + ) + answer_tokens = tokenizer.encode(answer_str) + else: + if use_prompt: + prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. + The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer. + The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . + User: {prompt_str}. Assistant: """) + else: + prompt_tokens = tokenizer.encode(prompt_str) + answer_tokens = tokenizer.encode(answer_str) + self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str)) + + def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]: + """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple.""" + return self._data[idx] + + def __len__(self) -> int: + """Returns the number of examples in the dataset.""" + return len(self._data) + + class Dataset: """ Light-weight wrapper to hold a dataset. @@ -82,6 +131,7 @@ class CompletionsDataset: def create_dataset( + args, data, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -90,31 +140,44 @@ def create_dataset( prompt_feature = prompt_feature or "prompt" completion_feature = completion_feature or "completion" sample = data[0] - if "messages" in sample: - return ChatDataset(data, tokenizer) - elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) - elif "text" in sample: - return Dataset(data, tokenizer) + + if args.training_mode == "normal": + if "messages" in sample: + return ChatDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) + elif "text" in sample: + return Dataset(data, tokenizer) + else: + raise ValueError( + "Unsupported data format, check the supported formats here:\n" + "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." + ) else: - raise ValueError( - "Unsupported data format, check the supported formats here:\n" - "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." + return GRPODataset( + data=data, + tokenizer=tokenizer, + prompt_key="prompt", + answer_key="answer", + use_chat_template=args.use_chat_template, + use_prompt=args.use_prompt ) def load_local_dataset( + args, data_path: Path, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, completion_feature: Optional[str] = None, ): def load_subset(path): + print(path) if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(args, data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -122,6 +185,7 @@ def load_local_dataset( def load_hf_dataset( + args, data_id: str, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -137,7 +201,7 @@ def load_hf_dataset( train, valid, test = [ ( create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature + args, dataset[n], tokenizer, prompt_feature, completion_feature ) if n in dataset.keys() else [] @@ -202,12 +266,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature + args, data_path, tokenizer, prompt_feature, completion_feature ) else: print(f"Loading Hugging Face dataset {args.data}.") train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature + args, args.data, tokenizer, prompt_feature, completion_feature ) if args.train and len(train) == 0: diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py new file mode 100644 index 00000000..639c3126 --- /dev/null +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -0,0 +1,690 @@ +# Copyright © 2024 Apple Inc. + +import time +from dataclasses import dataclass, field +from pathlib import Path +import re + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten + +from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches + +@dataclass +class GRPOTrainingArgs(TrainingArgs): + group_size: int = field( + default=4, + metadata={"help": "Number of responses per prompt."}, + ) + beta: float = field( + default=0.1, metadata={"help": "KL penalty coefficient."} + ) + epsilon: float = field( + default=1e-4, metadata={"help": "The Epsilon for numerical stability."} + ) + max_completion_length: int = field( + default=512, metadata={"help": "Number of Generations."} + ) + reference_model_path: str = field( + default=None, + metadata={ + "help": "Path to reference model weights. If None, uses the same model." + } + ) + + +def r1_extract_xml_answer(text: str) -> str: + """Extracts the answer from an XML formatted text string.""" + try: + answer = text.split("")[-1] + answer = answer.split("")[0] + return answer.strip() + except: + print("r1_extract_xml_answer returned empty string") + return "" + + +def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + """Ensures we always return a list of floats.""" + if not completions: + return [0.0] * len(prompts) + extracted_responses = [r1_extract_xml_answer(r) for r in completions] + return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses] + +def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + """Ensures we always return a list of floats.""" + if not completions or not answer: + return [0.0] * len(prompts) + extracted_responses = [r1_extract_xml_answer(r) for r in completions] + return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)] + + +def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + """Ensures we always return a list of floats.""" + if not completions: + return [0.0] * len(prompts) + pattern = r".*?\s*.*?" + matches = [bool(re.search(pattern, r)) if r else False for r in completions] + return [0.5 if match else 0.0 for match in matches] + + +def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + """Ensures we always return a list of floats.""" + if not completions: + return [0.0] * len(prompts) + pattern = r"^\n.*?\n\n\n.*?\n\n$" + matches = [bool(re.search(pattern, r)) if r else False for r in completions] + return [0.5 if match else 0.0 for match in matches] + + +def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + """Ensures we always return a list of floats.""" + if not completions: + return [0.0] * len(prompts) + + scores = [] + for text in completions: + if not text: + scores.append(0.0) + continue + + count = 0.0 + if text.count("\n") == 1: + count += 0.125 + if text.count("\n\n") == 1: + count += 0.125 + if text.count("\n\n") == 1: + count += 0.125 + if text.count("\n\n") == 1: + count += 0.125 + + # Penalize extra text after + end_text = text.split("\n\n")[-1] + count -= len(end_text) * 0.001 if len(end_text) > 0 else 0 + + scores.append(max(0.0, count)) # Ensure non-negative score + + return scores + + +def generate_grpo(model, prompt, max_tokens, tokenizer, temperature): + if len(prompt.shape) == 1: + prompt = prompt[None, :] + if prompt.shape[1] == 0: + return None + + end_sequence = tokenizer.encode("") + end_sequence_length = len(end_sequence) + output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32) + output[:prompt.shape[1]] = prompt[0] + current_length = prompt.shape[1] + + try: + def sample(logits): + if temperature > 0: + logits /= temperature + logprobs = logits - mx.logsumexp(logits, keepdims=True) + return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0] + + for _ in range(max_tokens): + current_input = output[:current_length][None, :] + logits = model(current_input) + token_logits = logits[0, -1] + next_token = sample(token_logits) + token_value = next_token.item() + output[current_length] = token_value + current_length += 1 + + if token_value == tokenizer.eos_token_id: + break + + if current_length >= end_sequence_length: + last_tokens = output[current_length - end_sequence_length:current_length].tolist() + # print(f"Last tokens: {last_tokens}") + # print(f"Decoded text: {tokenizer.decode(last_tokens)}") + # print(f"Target sequence: {end_sequence}") + if last_tokens == end_sequence: + break + + if current_length > prompt.shape[1]: + return output[:current_length] + + except Exception as e: + print(f"Generation error: {str(e)}") + return None + + return None + + +def get_per_token_logps(model, inputs, lengths): + logits = model(inputs).astype(mx.float16) + logits = logits[:, :-1, :] + targets = inputs[:, 1:] + + per_token_logps = [] + for i in range(logits.shape[0]): + seq_len = int(lengths[i]) - 1 + + seq_logits = logits[i, :seq_len] + seq_targets = targets[i, :seq_len] + + log_probs = nn.log_softmax(seq_logits, axis=-1) + + token_log_probs = mx.take_along_axis( + log_probs, + seq_targets.reshape(seq_len, 1), + axis=-1 + ).squeeze(-1) + + per_token_logps.append(token_log_probs) + mx.eval(logits) + return per_token_logps + + +def grpo_loss( + model, + tokenizer, + batch, + reward_funcs=None, + beta=0.1, + group_size=4, + epsilon=1e-4, + ref_model=None, + max_tokens=64, + temperature=1.0 +): + prompt_tokens, answer_tokens, prompt_text, answer_text = batch + batch_size = len(prompt_tokens) + + # Generation logic remains the same + all_completions = [] + all_completion_texts = [] + + for i in range(0, batch_size, batch_size): + batch_prompts = prompt_tokens[i:i+batch_size] + for prompt in batch_prompts: + prompt_tensor = mx.array(prompt) + for _ in range(group_size): + try: + completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature) + if completion_ids is not None: + completion_text = tokenizer.decode(completion_ids.tolist()) + all_completions.append(completion_ids) + all_completion_texts.append(completion_text) + + # Clear completion tensors + mx.eval(completion_ids) + del completion_ids + except Exception as e: + print(f"Generation error: {e}") + continue + + mx.metal.clear_cache() + + # Prepare inputs + expanded_answers = [] + expanded_prompts = [] + for i in range(batch_size): + expanded_answers.extend([answer_text[i]] * group_size) + expanded_prompts.extend([prompt_text[i]] * group_size) + + max_length = max(ids.shape[0] for ids in all_completions) + padded_completions = [] + attention_masks = [] + + for completion_ids in all_completions: + padding_length = max_length - completion_ids.shape[0] + if padding_length > 0: + padding = mx.zeros((padding_length,), dtype=completion_ids.dtype) + padded_ids = mx.concatenate([completion_ids, padding]) + mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)]) + else: + padded_ids = completion_ids + mask = mx.ones_like(completion_ids) + padded_completions.append(padded_ids) + attention_masks.append(mask) + + inputs = mx.stack(padded_completions) + attention_mask = mx.stack(attention_masks) + lengths = attention_mask.sum(axis=1) + + # Current policy probabilities + token_log_probs = get_per_token_logps(model, inputs, lengths) + + mx.eval(token_log_probs) + mx.metal.clear_cache() + + # Reference policy probabilities + if ref_model is not None: + ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths) + else: + ref_token_log_probs = token_log_probs + + max_len = max(x.shape[0] for x in token_log_probs) + padded_log_probs = [] + padded_ref_log_probs = [] + + for i in range(len(token_log_probs)): + seq_len = token_log_probs[i].shape[0] + padding = mx.zeros((max_len - seq_len,), dtype=mx.float16) + + padded_log_probs.append(mx.concatenate([token_log_probs[i], padding])) + padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding])) + + token_log_probs = mx.stack(padded_log_probs) + ref_token_log_probs = mx.stack(padded_ref_log_probs) + + # Calculate rewards and advantages + rewards = mx.zeros((len(all_completions),)) + for reward_func in reward_funcs: + func_rewards = mx.array(reward_func( + prompts=expanded_prompts, + completions=all_completion_texts, + answer=expanded_answers + )) + rewards += func_rewards + + if len(reward_funcs) > 1: + rewards /= len(reward_funcs) + + # Reshape rewards and compute advantages following GRPO formula + rewards_reshaped = rewards.reshape(batch_size, group_size) + mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1) + std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1) + advantages = (rewards - mean_rewards) / (std_rewards + epsilon) + + # Compute KL divergence using Schulman's approximator + kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs) + + # Create mask for valid tokens + length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1) + + # Compute policy ratio + policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) + + # Compute per-token loss following GRPO formula + per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask) + + # Average over tokens and sequences + sequence_sums = per_token_loss.sum(axis=1) + sequence_lengths = length_mask.sum(axis=1) + + loss = (sequence_sums / sequence_lengths).mean() + + # Calculate mean KL divergence for metrics + mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean() + + # Collect reward metrics + reward_metrics = {} + for i, reward_func in enumerate(reward_funcs): + func_name = reward_func.__name__ + func_rewards = mx.array(reward_func( + prompts=expanded_prompts, + completions=all_completion_texts, + answer=expanded_answers + )) + reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards) + reward_metrics[f'{func_name}_std'] = mx.std(func_rewards) + + metrics = { + 'total_rewards_mean': mx.mean(rewards), + 'total_rewards_std': mx.std(rewards), + 'grouped_rewards_mean': mx.mean(rewards_reshaped), + 'grouped_rewards_std': mx.std(rewards_reshaped), + 'kl': mean_kl, + **reward_metrics + } + mx.metal.clear_cache() + + return loss, sequence_lengths.sum(), metrics + + +def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + """Memory-optimized version of iterate_grpo_batches""" + if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4: + raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples") + + # Sort by length but use generator to avoid keeping full sorted list in memory + def length_key(i): + return len(dataset[i][0]) + len(dataset[i][1]) + + idx = sorted(range(len(dataset)), key=length_key) + + if len(dataset) < batch_size: + raise ValueError( + f"Dataset must have at least batch_size={batch_size} " + f"examples but only has {len(dataset)}." + ) + + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + + # Use generator for batch indices + def batch_index_generator(): + for i in range(0, len(idx) - batch_size + 1, batch_size): + yield idx[i : i + batch_size : step] + + while True: + indices = ( + np.random.permutation(list(batch_index_generator())) if train + else batch_index_generator() + ) + + for batch_idx in indices: + current_batch = [dataset[j] for j in batch_idx] + + prompts_tokens = [item[0] for item in current_batch] + answers_tokens = [item[1] for item in current_batch] + prompts_text = [item[2] for item in current_batch] + answers_text = [item[3] for item in current_batch] + + if any(len(p) > max_seq_length for p in prompts_tokens): + print( + f"[WARNING] Some prompts are longer than {max_seq_length} tokens. " + "Long prompts will be truncated." + ) + + yield prompts_tokens, answers_tokens, prompts_text, answers_text + + if not train: + break + + +def evaluate_grpo( + model, + ref_model, + dataset, + tokenizer, + batch_size, + num_batches, + beta: float, + epsilon: float, + group_size: int, + max_seq_length, + reward_funcs = None, + loss_fn: callable = grpo_loss, + iterate_batches: callable = iterate_grpo_batches +): + """ + Evaluate model using GRPO loss. + Returns: + tuple: (average loss, number of tokens, average metrics) + """ + all_losses = 0 + ntokens = 0 + all_metrics = None # Initialize metrics dictionary + + # Create iterator for batches + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + # Iterate through batches + for _, batch in zip( + index_iterator, + iterate_batches( + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + # Calculate loss for current batch + losses, toks, metrics = loss_fn( + model=model, + tokenizer=tokenizer, + batch=batch, + reward_funcs=reward_funcs, + beta=beta, + group_size=group_size, + epsilon=epsilon, + ref_model=ref_model + ) + + # Accumulate losses and tokens + all_losses += losses * toks + ntokens += toks + + # Accumulate metrics + if all_metrics is None: + all_metrics = {k: v * toks for k, v in metrics.items()} + else: + for k, v in metrics.items(): + all_metrics[k] += v * toks + + # Evaluate accumulated values + mx.eval(all_losses, ntokens) + + # Aggregate across distributed workers + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) + all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} + + # Calculate averages + avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()} + avg_loss = (all_losses / ntokens).item() + + return avg_loss, ntokens, avg_metrics + + +def train_grpo( + model, + ref_model, + tokenizer, + optimizer, + train_dataset, + val_dataset, + reward_funcs = [ + r1_accuracy_reward_func, + r1_int_reward_func, + r1_strict_format_reward_func, + r1_soft_format_reward_func, + r1_count_xml + ], + args: GRPOTrainingArgs = GRPOTrainingArgs(), + loss_fn: callable = grpo_loss, + iterate_batches: callable = iterate_grpo_batches, + training_callback: TrainingCallback = None, +): + print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") + + if args.grad_checkpoint: + grad_checkpoint(model.layers[0]) + + state = [model.state, optimizer.state] + + def step(batch): + + # Forward and backward pass + (loss, toks, metrics), grad = loss_value_and_grad( + model, + tokenizer=tokenizer, + batch=batch, + reward_funcs=reward_funcs, + beta=args.beta, + group_size=args.group_size, + epsilon=args.epsilon, + ref_model=ref_model, + max_tokens=args.max_completion_length, + ) + + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + + # Model update + optimizer.update(model, grad) + + return loss, toks, metrics + + loss_value_and_grad = nn.value_and_grad(model, loss_fn) + + losses = 0 + n_tokens = 0 + steps = 0 + trained_tokens = 0 + accumulated_metrics = { + 'total_rewards_mean': 0, + 'total_rewards_std': 0, + 'grouped_rewards_mean': 0, + 'grouped_rewards_std': 0, + 'kl': 0 + } + for reward_func in reward_funcs: + func_name = reward_func.__name__ + accumulated_metrics[f'{func_name}_mean'] = 0 + accumulated_metrics[f'{func_name}_std'] = 0 + + start = time.perf_counter() + for it, batch in zip( + range(1, args.iters + 1), + iterate_batches( + dataset=train_dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + train=True, + ), + ): + # Report validation loss if needed, the first validation loss + # is always measured before any training. + if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: + stop = time.perf_counter() + val_loss, val_ntokens, val_metrics = evaluate_grpo( + model=model, + dataset=val_dataset, + loss_fn=loss_fn, + ref_model=ref_model, + reward_funcs=reward_funcs, + tokenizer=tokenizer, + group_size=args.group_size, + batch_size=args.batch_size, + num_batches=args.val_batches, + max_seq_length=args.max_seq_length, + beta=args.beta, + epsilon=args.epsilon, + iterate_batches=iterate_batches, + ) + val_time = time.perf_counter() - stop + if rank == 0: + val_metrics_str = ( + f"Val loss {val_loss:.8f}, " + f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, " + f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, " + f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, " + f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, " + f"Val kl {val_metrics['kl']:.3f}" + ) + + # Add reward function specific metrics + for i, reward_func in enumerate(reward_funcs): + val_metrics_str += ( + f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, " + f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}" + ) + + print( + f"Iter {it}: {val_metrics_str}, " + f"Val took {val_time:.3f}s", + flush=True, + ) + + if training_callback is not None: + training_callback.on_val_loss_report({ + "iteration": it, + "val_loss": val_loss, + **{f"val_{k}": v for k, v in val_metrics.items()}, + "val_time": val_time, + }) + + start = time.perf_counter() + + loss, toks, metrics = step(batch) + losses += loss + n_tokens += toks + steps += 1 + + for k, v in metrics.items(): + accumulated_metrics[k] += v + + mx.eval(state, losses, n_tokens) + + if it % args.steps_per_report == 0 or it == args.iters: + stop = time.perf_counter() + + train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() + train_loss /= steps * mx.distributed.init().size() + avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()} + n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() + learning_rate = optimizer.learning_rate.item() + it_sec = args.steps_per_report / (stop - start) + tokens_sec = float(n_tokens) / (stop - start) + trained_tokens += n_tokens + peak_mem = mx.metal.get_peak_memory() / 1e9 + + if rank == 0: + train_metrics_str = ( + f"Train loss {train_loss:.8f}, " + f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, " + f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, " + f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, " + f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, " + f"KL {avg_metrics['kl']:.3f}" + ) + + # Add reward function specific metrics + for i, reward_func in enumerate(reward_funcs): + func_name = reward_func.__name__ + train_metrics_str += ( + f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, " + f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}" + ) + + print( + f"Iter {it}: {train_metrics_str}, " + f"Learning Rate {learning_rate:.3e}, " + f"It/sec {it_sec:.3f}, " + f"Tokens/sec {tokens_sec:.3f}, " + f"Peak mem {peak_mem:.3f} GB", + flush=True, + ) + + if training_callback is not None: + training_callback.on_train_loss_report({ + "iteration": it, + "train_loss": train_loss, + **{f"train_{k}": v for k, v in avg_metrics.items()}, + "learning_rate": learning_rate, + "iterations_per_second": it_sec, + "tokens_per_second": tokens_sec, + "trained_tokens": trained_tokens, + "peak_memory": peak_mem, + }) + + losses = 0 + n_tokens = 0 + steps = 0 + start = time.perf_counter() + + # Save adapter weights + if it % args.steps_per_save == 0: + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + checkpoint = ( + Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" + ) + mx.save_safetensors(str(checkpoint), adapter_weights) + print( + f"Iter {it}: Saved adapter weights to " + f"{args.adapter_file} and {checkpoint}." + ) + + # Save final weights + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + print(f"Saved final weights to {args.adapter_file}.") \ No newline at end of file