diff --git a/llms/mlx_lm/convert2.py b/llms/mlx_lm/convert2.py
new file mode 100644
index 00000000..d4bf996d
--- /dev/null
+++ b/llms/mlx_lm/convert2.py
@@ -0,0 +1,19 @@
+import pandas as pd
+import os
+
+# Define dataset directory
+dataset_dir = "/Users/cshang/Desktop/test_grpo/data"
+
+# Convert each Parquet file to JSONL
+for file in os.listdir(dataset_dir):
+ if file.endswith(".parquet"):
+ parquet_path = os.path.join(dataset_dir, file)
+ jsonl_path = os.path.join(dataset_dir, file.replace(".parquet", ".jsonl"))
+
+ # Load Parquet file
+ df = pd.read_parquet(parquet_path)
+
+ # Convert to JSONL format
+ df.to_json(jsonl_path, orient="records", lines=True)
+
+ print(f"Converted {parquet_path} -> {jsonl_path}")
\ No newline at end of file
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 43f508c3..2672f3b7 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -15,6 +15,7 @@ import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -42,6 +43,7 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
+ "training_mode": "normal",
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
@@ -62,6 +64,15 @@ CONFIG_DEFAULTS = {
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+
+ # GRPO args
+ "reference_model_path": None,
+ "group_size": 4,
+ "beta": 0.1,
+ "epsilon": 1e-4,
+ "max_completion_length": 512,
+ "use_chat_template": False,
+ "use_prompt": False,
}
@@ -94,6 +105,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
+ parser.add_argument(
+ "--training-mode",
+ type=str,
+ choices=["normal", "grpo"],
+ help="Training mode: normal or GRPO",
+ )
parser.add_argument(
"--num-layers",
type=int,
@@ -161,6 +178,44 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
+
+ # GRPO args
+ parser.add_argument(
+ "--group-size",
+ type=int,
+ help="Number of generations.",
+ default=4,
+ )
+ parser.add_argument(
+ "--max-completion-length",
+ type=int,
+ help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
+ default=512,
+ )
+ parser.add_argument(
+ "--beta",
+ type=float,
+ help="KL penalty coefficient.",
+ default=0.1,
+ )
+ parser.add_argument(
+ "--epsilon",
+ type=float,
+ help="The Epsilon for numerical stability.",
+ default=1e-4,
+ )
+ parser.add_argument(
+ "--use-chat-template",
+ action="store_true",
+ help="If the model is a Chat model, use the Chat template.",
+ default=None,
+ )
+ parser.add_argument(
+ "--use-prompt",
+ action="store_true",
+ help="Rather to use the prompt from the R1 paper.",
+ default=None,
+ )
return parser
@@ -220,32 +275,102 @@ def train_model(
)
)
# Train model
- train(
- model=model,
- tokenizer=tokenizer,
- args=training_args,
- optimizer=opt,
- train_dataset=train_set,
- val_dataset=valid_set,
- training_callback=training_callback,
- )
+ if args.training_mode == "grpo":
+ training_args = GRPOTrainingArgs(
+ batch_size=args.batch_size,
+ iters=args.iters,
+ val_batches=args.val_batches,
+ steps_per_report=args.steps_per_report,
+ steps_per_eval=args.steps_per_eval,
+ steps_per_save=args.save_every,
+ adapter_file=adapter_file,
+ max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
+ grad_checkpoint=args.grad_checkpoint,
+ beta=args.beta,
+ group_size=args.group_size,
+ epsilon=args.epsilon,
+ reference_model_path=args.reference_model_path
+ )
+
+ if args.reference_model_path:
+ reference_model, _ = load(args.reference_model_path)
+ reference_model = reference_model.freeze()
+ else:
+ reference_model, _ = load(args.model)
+
+ train_grpo(
+ model=model,
+ ref_model=reference_model,
+ tokenizer=tokenizer,
+ optimizer=opt,
+ train_dataset=train_set,
+ val_dataset=valid_set,
+ args=training_args,
+ training_callback=training_callback,
+ )
+ else:
+ training_args = TrainingArgs(
+ batch_size=args.batch_size,
+ iters=args.iters,
+ val_batches=args.val_batches,
+ steps_per_report=args.steps_per_report,
+ steps_per_eval=args.steps_per_eval,
+ steps_per_save=args.save_every,
+ adapter_file=adapter_file,
+ max_seq_length=args.max_seq_length,
+ grad_checkpoint=args.grad_checkpoint
+ )
+
+ train(
+ model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ optimizer=opt,
+ train_dataset=train_set,
+ val_dataset=valid_set,
+ training_callback=training_callback,
+ )
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
- test_loss = evaluate(
- model=model,
- dataset=test_set,
- tokenizer=tokenizer,
- batch_size=args.batch_size,
- num_batches=args.test_batches,
- max_seq_length=args.max_seq_length,
- )
+ if args.training_mode == "grpo":
+ if args.reference_model_path:
+ reference_model, _ = load(args.reference_model_path)
+ else:
+ reference_model = model
- test_ppl = math.exp(test_loss)
+ test_loss, _, test_rewards = evaluate_grpo(
+ model=model,
+ ref_model=reference_model,
+ dataset=test_set,
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ num_batches=args.test_batches,
+ max_seq_length=args.max_seq_length,
+ beta=args.beta,
+ group_size=args.group_size,
+ epsilon=args.epsilon
+ )
- print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+ test_ppl = math.exp(test_loss)
+
+ print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+ else:
+ test_loss = evaluate(
+ model=model,
+ dataset=test_set,
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ num_batches=args.test_batches,
+ max_seq_length=args.max_seq_length,
+ )
+
+ test_ppl = math.exp(test_loss)
+
+ print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None):
diff --git a/llms/mlx_lm/test_grpo b/llms/mlx_lm/test_grpo
new file mode 160000
index 00000000..a74695c9
--- /dev/null
+++ b/llms/mlx_lm/test_grpo
@@ -0,0 +1 @@
+Subproject commit a74695c9280dd46208ea000f507f44bc8ddd9533
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 377e7cae..5d0ff68e 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,10 +1,59 @@
import json
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
+class GRPODataset:
+ """
+ Dataset wrapper for GRPO training data.
+ Each example should have a 'prompt' and 'answer' field.
+ Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
+ """
+ def __init__(
+ self,
+ data: List[Dict[str, str]],
+ tokenizer: PreTrainedTokenizer,
+ prompt_key: str = "prompt",
+ answer_key: str = "answer",
+ use_chat_template: bool = False,
+ use_prompt: bool = False
+ ):
+ self._data = []
+ for item in data:
+ prompt_str = str(item[prompt_key])
+ answer_str = str(item[answer_key])
+ if use_chat_template:
+ prompt_tokens = tokenizer.apply_chat_template(
+ [
+ {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+ The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
+ The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""},
+ {'role': 'user', 'content': prompt_str}
+ ],
+ )
+ answer_tokens = tokenizer.encode(answer_str)
+ else:
+ if use_prompt:
+ prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+ The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
+ The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here .
+ User: {prompt_str}. Assistant: """)
+ else:
+ prompt_tokens = tokenizer.encode(prompt_str)
+ answer_tokens = tokenizer.encode(answer_str)
+ self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
+
+ def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+ """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
+ return self._data[idx]
+
+ def __len__(self) -> int:
+ """Returns the number of examples in the dataset."""
+ return len(self._data)
+
+
class Dataset:
"""
Light-weight wrapper to hold a dataset.
@@ -82,6 +131,7 @@ class CompletionsDataset:
def create_dataset(
+ args,
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
@@ -90,31 +140,44 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
- if "messages" in sample:
- return ChatDataset(data, tokenizer)
- elif prompt_feature in sample and completion_feature in sample:
- return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
- elif "text" in sample:
- return Dataset(data, tokenizer)
+
+ if args.training_mode == "normal":
+ if "messages" in sample:
+ return ChatDataset(data, tokenizer)
+ elif prompt_feature in sample and completion_feature in sample:
+ return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+ elif "text" in sample:
+ return Dataset(data, tokenizer)
+ else:
+ raise ValueError(
+ "Unsupported data format, check the supported formats here:\n"
+ "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+ )
else:
- raise ValueError(
- "Unsupported data format, check the supported formats here:\n"
- "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+ return GRPODataset(
+ data=data,
+ tokenizer=tokenizer,
+ prompt_key="prompt",
+ answer_key="answer",
+ use_chat_template=args.use_chat_template,
+ use_prompt=args.use_prompt
)
def load_local_dataset(
+ args,
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
def load_subset(path):
+ print(path)
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
- return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+ return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -122,6 +185,7 @@ def load_local_dataset(
def load_hf_dataset(
+ args,
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
@@ -137,7 +201,7 @@ def load_hf_dataset(
train, valid, test = [
(
create_dataset(
- dataset[n], tokenizer, prompt_feature, completion_feature
+ args, dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
@@ -202,12 +266,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(
- data_path, tokenizer, prompt_feature, completion_feature
+ args, data_path, tokenizer, prompt_feature, completion_feature
)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(
- args.data, tokenizer, prompt_feature, completion_feature
+ args, args.data, tokenizer, prompt_feature, completion_feature
)
if args.train and len(train) == 0:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
new file mode 100644
index 00000000..639c3126
--- /dev/null
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -0,0 +1,690 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+import re
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from mlx.utils import tree_flatten
+
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
+
+@dataclass
+class GRPOTrainingArgs(TrainingArgs):
+ group_size: int = field(
+ default=4,
+ metadata={"help": "Number of responses per prompt."},
+ )
+ beta: float = field(
+ default=0.1, metadata={"help": "KL penalty coefficient."}
+ )
+ epsilon: float = field(
+ default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
+ )
+ max_completion_length: int = field(
+ default=512, metadata={"help": "Number of Generations."}
+ )
+ reference_model_path: str = field(
+ default=None,
+ metadata={
+ "help": "Path to reference model weights. If None, uses the same model."
+ }
+ )
+
+
+def r1_extract_xml_answer(text: str) -> str:
+ """Extracts the answer from an XML formatted text string."""
+ try:
+ answer = text.split("")[-1]
+ answer = answer.split("")[0]
+ return answer.strip()
+ except:
+ print("r1_extract_xml_answer returned empty string")
+ return ""
+
+
+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ """Ensures we always return a list of floats."""
+ if not completions:
+ return [0.0] * len(prompts)
+ extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+ return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
+
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ """Ensures we always return a list of floats."""
+ if not completions or not answer:
+ return [0.0] * len(prompts)
+ extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+ return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+
+def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ """Ensures we always return a list of floats."""
+ if not completions:
+ return [0.0] * len(prompts)
+ pattern = r".*?\s*.*?"
+ matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+ return [0.5 if match else 0.0 for match in matches]
+
+
+def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ """Ensures we always return a list of floats."""
+ if not completions:
+ return [0.0] * len(prompts)
+ pattern = r"^\n.*?\n\n\n.*?\n\n$"
+ matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+ return [0.5 if match else 0.0 for match in matches]
+
+
+def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ """Ensures we always return a list of floats."""
+ if not completions:
+ return [0.0] * len(prompts)
+
+ scores = []
+ for text in completions:
+ if not text:
+ scores.append(0.0)
+ continue
+
+ count = 0.0
+ if text.count("\n") == 1:
+ count += 0.125
+ if text.count("\n\n") == 1:
+ count += 0.125
+ if text.count("\n\n") == 1:
+ count += 0.125
+ if text.count("\n\n") == 1:
+ count += 0.125
+
+ # Penalize extra text after
+ end_text = text.split("\n\n")[-1]
+ count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
+
+ scores.append(max(0.0, count)) # Ensure non-negative score
+
+ return scores
+
+
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
+ if len(prompt.shape) == 1:
+ prompt = prompt[None, :]
+ if prompt.shape[1] == 0:
+ return None
+
+ end_sequence = tokenizer.encode("")
+ end_sequence_length = len(end_sequence)
+ output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
+ output[:prompt.shape[1]] = prompt[0]
+ current_length = prompt.shape[1]
+
+ try:
+ def sample(logits):
+ if temperature > 0:
+ logits /= temperature
+ logprobs = logits - mx.logsumexp(logits, keepdims=True)
+ return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
+
+ for _ in range(max_tokens):
+ current_input = output[:current_length][None, :]
+ logits = model(current_input)
+ token_logits = logits[0, -1]
+ next_token = sample(token_logits)
+ token_value = next_token.item()
+ output[current_length] = token_value
+ current_length += 1
+
+ if token_value == tokenizer.eos_token_id:
+ break
+
+ if current_length >= end_sequence_length:
+ last_tokens = output[current_length - end_sequence_length:current_length].tolist()
+ # print(f"Last tokens: {last_tokens}")
+ # print(f"Decoded text: {tokenizer.decode(last_tokens)}")
+ # print(f"Target sequence: {end_sequence}")
+ if last_tokens == end_sequence:
+ break
+
+ if current_length > prompt.shape[1]:
+ return output[:current_length]
+
+ except Exception as e:
+ print(f"Generation error: {str(e)}")
+ return None
+
+ return None
+
+
+def get_per_token_logps(model, inputs, lengths):
+ logits = model(inputs).astype(mx.float16)
+ logits = logits[:, :-1, :]
+ targets = inputs[:, 1:]
+
+ per_token_logps = []
+ for i in range(logits.shape[0]):
+ seq_len = int(lengths[i]) - 1
+
+ seq_logits = logits[i, :seq_len]
+ seq_targets = targets[i, :seq_len]
+
+ log_probs = nn.log_softmax(seq_logits, axis=-1)
+
+ token_log_probs = mx.take_along_axis(
+ log_probs,
+ seq_targets.reshape(seq_len, 1),
+ axis=-1
+ ).squeeze(-1)
+
+ per_token_logps.append(token_log_probs)
+ mx.eval(logits)
+ return per_token_logps
+
+
+def grpo_loss(
+ model,
+ tokenizer,
+ batch,
+ reward_funcs=None,
+ beta=0.1,
+ group_size=4,
+ epsilon=1e-4,
+ ref_model=None,
+ max_tokens=64,
+ temperature=1.0
+):
+ prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+ batch_size = len(prompt_tokens)
+
+ # Generation logic remains the same
+ all_completions = []
+ all_completion_texts = []
+
+ for i in range(0, batch_size, batch_size):
+ batch_prompts = prompt_tokens[i:i+batch_size]
+ for prompt in batch_prompts:
+ prompt_tensor = mx.array(prompt)
+ for _ in range(group_size):
+ try:
+ completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
+ if completion_ids is not None:
+ completion_text = tokenizer.decode(completion_ids.tolist())
+ all_completions.append(completion_ids)
+ all_completion_texts.append(completion_text)
+
+ # Clear completion tensors
+ mx.eval(completion_ids)
+ del completion_ids
+ except Exception as e:
+ print(f"Generation error: {e}")
+ continue
+
+ mx.metal.clear_cache()
+
+ # Prepare inputs
+ expanded_answers = []
+ expanded_prompts = []
+ for i in range(batch_size):
+ expanded_answers.extend([answer_text[i]] * group_size)
+ expanded_prompts.extend([prompt_text[i]] * group_size)
+
+ max_length = max(ids.shape[0] for ids in all_completions)
+ padded_completions = []
+ attention_masks = []
+
+ for completion_ids in all_completions:
+ padding_length = max_length - completion_ids.shape[0]
+ if padding_length > 0:
+ padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
+ padded_ids = mx.concatenate([completion_ids, padding])
+ mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
+ else:
+ padded_ids = completion_ids
+ mask = mx.ones_like(completion_ids)
+ padded_completions.append(padded_ids)
+ attention_masks.append(mask)
+
+ inputs = mx.stack(padded_completions)
+ attention_mask = mx.stack(attention_masks)
+ lengths = attention_mask.sum(axis=1)
+
+ # Current policy probabilities
+ token_log_probs = get_per_token_logps(model, inputs, lengths)
+
+ mx.eval(token_log_probs)
+ mx.metal.clear_cache()
+
+ # Reference policy probabilities
+ if ref_model is not None:
+ ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+ else:
+ ref_token_log_probs = token_log_probs
+
+ max_len = max(x.shape[0] for x in token_log_probs)
+ padded_log_probs = []
+ padded_ref_log_probs = []
+
+ for i in range(len(token_log_probs)):
+ seq_len = token_log_probs[i].shape[0]
+ padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
+
+ padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+ padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+
+ token_log_probs = mx.stack(padded_log_probs)
+ ref_token_log_probs = mx.stack(padded_ref_log_probs)
+
+ # Calculate rewards and advantages
+ rewards = mx.zeros((len(all_completions),))
+ for reward_func in reward_funcs:
+ func_rewards = mx.array(reward_func(
+ prompts=expanded_prompts,
+ completions=all_completion_texts,
+ answer=expanded_answers
+ ))
+ rewards += func_rewards
+
+ if len(reward_funcs) > 1:
+ rewards /= len(reward_funcs)
+
+ # Reshape rewards and compute advantages following GRPO formula
+ rewards_reshaped = rewards.reshape(batch_size, group_size)
+ mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
+ std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
+ advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
+
+ # Compute KL divergence using Schulman's approximator
+ kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
+
+ # Create mask for valid tokens
+ length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
+
+ # Compute policy ratio
+ policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
+
+ # Compute per-token loss following GRPO formula
+ per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
+
+ # Average over tokens and sequences
+ sequence_sums = per_token_loss.sum(axis=1)
+ sequence_lengths = length_mask.sum(axis=1)
+
+ loss = (sequence_sums / sequence_lengths).mean()
+
+ # Calculate mean KL divergence for metrics
+ mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
+
+ # Collect reward metrics
+ reward_metrics = {}
+ for i, reward_func in enumerate(reward_funcs):
+ func_name = reward_func.__name__
+ func_rewards = mx.array(reward_func(
+ prompts=expanded_prompts,
+ completions=all_completion_texts,
+ answer=expanded_answers
+ ))
+ reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
+ reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
+
+ metrics = {
+ 'total_rewards_mean': mx.mean(rewards),
+ 'total_rewards_std': mx.std(rewards),
+ 'grouped_rewards_mean': mx.mean(rewards_reshaped),
+ 'grouped_rewards_std': mx.std(rewards_reshaped),
+ 'kl': mean_kl,
+ **reward_metrics
+ }
+ mx.metal.clear_cache()
+
+ return loss, sequence_lengths.sum(), metrics
+
+
+def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+ """Memory-optimized version of iterate_grpo_batches"""
+ if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+ raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+ # Sort by length but use generator to avoid keeping full sorted list in memory
+ def length_key(i):
+ return len(dataset[i][0]) + len(dataset[i][1])
+
+ idx = sorted(range(len(dataset)), key=length_key)
+
+ if len(dataset) < batch_size:
+ raise ValueError(
+ f"Dataset must have at least batch_size={batch_size} "
+ f"examples but only has {len(dataset)}."
+ )
+
+ step = mx.distributed.init().size()
+ if batch_size % step != 0:
+ raise ValueError("The batch size must be divisible by the number of workers")
+
+ # Use generator for batch indices
+ def batch_index_generator():
+ for i in range(0, len(idx) - batch_size + 1, batch_size):
+ yield idx[i : i + batch_size : step]
+
+ while True:
+ indices = (
+ np.random.permutation(list(batch_index_generator())) if train
+ else batch_index_generator()
+ )
+
+ for batch_idx in indices:
+ current_batch = [dataset[j] for j in batch_idx]
+
+ prompts_tokens = [item[0] for item in current_batch]
+ answers_tokens = [item[1] for item in current_batch]
+ prompts_text = [item[2] for item in current_batch]
+ answers_text = [item[3] for item in current_batch]
+
+ if any(len(p) > max_seq_length for p in prompts_tokens):
+ print(
+ f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
+ "Long prompts will be truncated."
+ )
+
+ yield prompts_tokens, answers_tokens, prompts_text, answers_text
+
+ if not train:
+ break
+
+
+def evaluate_grpo(
+ model,
+ ref_model,
+ dataset,
+ tokenizer,
+ batch_size,
+ num_batches,
+ beta: float,
+ epsilon: float,
+ group_size: int,
+ max_seq_length,
+ reward_funcs = None,
+ loss_fn: callable = grpo_loss,
+ iterate_batches: callable = iterate_grpo_batches
+):
+ """
+ Evaluate model using GRPO loss.
+ Returns:
+ tuple: (average loss, number of tokens, average metrics)
+ """
+ all_losses = 0
+ ntokens = 0
+ all_metrics = None # Initialize metrics dictionary
+
+ # Create iterator for batches
+ index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+ # Iterate through batches
+ for _, batch in zip(
+ index_iterator,
+ iterate_batches(
+ dataset=dataset,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_length=max_seq_length,
+ ),
+ ):
+ # Calculate loss for current batch
+ losses, toks, metrics = loss_fn(
+ model=model,
+ tokenizer=tokenizer,
+ batch=batch,
+ reward_funcs=reward_funcs,
+ beta=beta,
+ group_size=group_size,
+ epsilon=epsilon,
+ ref_model=ref_model
+ )
+
+ # Accumulate losses and tokens
+ all_losses += losses * toks
+ ntokens += toks
+
+ # Accumulate metrics
+ if all_metrics is None:
+ all_metrics = {k: v * toks for k, v in metrics.items()}
+ else:
+ for k, v in metrics.items():
+ all_metrics[k] += v * toks
+
+ # Evaluate accumulated values
+ mx.eval(all_losses, ntokens)
+
+ # Aggregate across distributed workers
+ all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+ ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+ all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
+
+ # Calculate averages
+ avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
+ avg_loss = (all_losses / ntokens).item()
+
+ return avg_loss, ntokens, avg_metrics
+
+
+def train_grpo(
+ model,
+ ref_model,
+ tokenizer,
+ optimizer,
+ train_dataset,
+ val_dataset,
+ reward_funcs = [
+ r1_accuracy_reward_func,
+ r1_int_reward_func,
+ r1_strict_format_reward_func,
+ r1_soft_format_reward_func,
+ r1_count_xml
+ ],
+ args: GRPOTrainingArgs = GRPOTrainingArgs(),
+ loss_fn: callable = grpo_loss,
+ iterate_batches: callable = iterate_grpo_batches,
+ training_callback: TrainingCallback = None,
+):
+ print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
+ world = mx.distributed.init()
+ world_size = world.size()
+ rank = world.rank()
+ if world_size > 1:
+ print(f"Node {rank} of {world_size}")
+
+ if args.grad_checkpoint:
+ grad_checkpoint(model.layers[0])
+
+ state = [model.state, optimizer.state]
+
+ def step(batch):
+
+ # Forward and backward pass
+ (loss, toks, metrics), grad = loss_value_and_grad(
+ model,
+ tokenizer=tokenizer,
+ batch=batch,
+ reward_funcs=reward_funcs,
+ beta=args.beta,
+ group_size=args.group_size,
+ epsilon=args.epsilon,
+ ref_model=ref_model,
+ max_tokens=args.max_completion_length,
+ )
+
+ # All reduce the gradients if running in distributed mode
+ grad = average_gradients(grad)
+
+ # Model update
+ optimizer.update(model, grad)
+
+ return loss, toks, metrics
+
+ loss_value_and_grad = nn.value_and_grad(model, loss_fn)
+
+ losses = 0
+ n_tokens = 0
+ steps = 0
+ trained_tokens = 0
+ accumulated_metrics = {
+ 'total_rewards_mean': 0,
+ 'total_rewards_std': 0,
+ 'grouped_rewards_mean': 0,
+ 'grouped_rewards_std': 0,
+ 'kl': 0
+ }
+ for reward_func in reward_funcs:
+ func_name = reward_func.__name__
+ accumulated_metrics[f'{func_name}_mean'] = 0
+ accumulated_metrics[f'{func_name}_std'] = 0
+
+ start = time.perf_counter()
+ for it, batch in zip(
+ range(1, args.iters + 1),
+ iterate_batches(
+ dataset=train_dataset,
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ max_seq_length=args.max_seq_length,
+ train=True,
+ ),
+ ):
+ # Report validation loss if needed, the first validation loss
+ # is always measured before any training.
+ if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
+ stop = time.perf_counter()
+ val_loss, val_ntokens, val_metrics = evaluate_grpo(
+ model=model,
+ dataset=val_dataset,
+ loss_fn=loss_fn,
+ ref_model=ref_model,
+ reward_funcs=reward_funcs,
+ tokenizer=tokenizer,
+ group_size=args.group_size,
+ batch_size=args.batch_size,
+ num_batches=args.val_batches,
+ max_seq_length=args.max_seq_length,
+ beta=args.beta,
+ epsilon=args.epsilon,
+ iterate_batches=iterate_batches,
+ )
+ val_time = time.perf_counter() - stop
+ if rank == 0:
+ val_metrics_str = (
+ f"Val loss {val_loss:.8f}, "
+ f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
+ f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
+ f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
+ f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
+ f"Val kl {val_metrics['kl']:.3f}"
+ )
+
+ # Add reward function specific metrics
+ for i, reward_func in enumerate(reward_funcs):
+ val_metrics_str += (
+ f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
+ f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
+ )
+
+ print(
+ f"Iter {it}: {val_metrics_str}, "
+ f"Val took {val_time:.3f}s",
+ flush=True,
+ )
+
+ if training_callback is not None:
+ training_callback.on_val_loss_report({
+ "iteration": it,
+ "val_loss": val_loss,
+ **{f"val_{k}": v for k, v in val_metrics.items()},
+ "val_time": val_time,
+ })
+
+ start = time.perf_counter()
+
+ loss, toks, metrics = step(batch)
+ losses += loss
+ n_tokens += toks
+ steps += 1
+
+ for k, v in metrics.items():
+ accumulated_metrics[k] += v
+
+ mx.eval(state, losses, n_tokens)
+
+ if it % args.steps_per_report == 0 or it == args.iters:
+ stop = time.perf_counter()
+
+ train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
+ train_loss /= steps * mx.distributed.init().size()
+ avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
+ n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
+ learning_rate = optimizer.learning_rate.item()
+ it_sec = args.steps_per_report / (stop - start)
+ tokens_sec = float(n_tokens) / (stop - start)
+ trained_tokens += n_tokens
+ peak_mem = mx.metal.get_peak_memory() / 1e9
+
+ if rank == 0:
+ train_metrics_str = (
+ f"Train loss {train_loss:.8f}, "
+ f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
+ f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
+ f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
+ f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
+ f"KL {avg_metrics['kl']:.3f}"
+ )
+
+ # Add reward function specific metrics
+ for i, reward_func in enumerate(reward_funcs):
+ func_name = reward_func.__name__
+ train_metrics_str += (
+ f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
+ f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
+ )
+
+ print(
+ f"Iter {it}: {train_metrics_str}, "
+ f"Learning Rate {learning_rate:.3e}, "
+ f"It/sec {it_sec:.3f}, "
+ f"Tokens/sec {tokens_sec:.3f}, "
+ f"Peak mem {peak_mem:.3f} GB",
+ flush=True,
+ )
+
+ if training_callback is not None:
+ training_callback.on_train_loss_report({
+ "iteration": it,
+ "train_loss": train_loss,
+ **{f"train_{k}": v for k, v in avg_metrics.items()},
+ "learning_rate": learning_rate,
+ "iterations_per_second": it_sec,
+ "tokens_per_second": tokens_sec,
+ "trained_tokens": trained_tokens,
+ "peak_memory": peak_mem,
+ })
+
+ losses = 0
+ n_tokens = 0
+ steps = 0
+ start = time.perf_counter()
+
+ # Save adapter weights
+ if it % args.steps_per_save == 0:
+ adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+ mx.save_safetensors(str(args.adapter_file), adapter_weights)
+ checkpoint = (
+ Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
+ )
+ mx.save_safetensors(str(checkpoint), adapter_weights)
+ print(
+ f"Iter {it}: Saved adapter weights to "
+ f"{args.adapter_file} and {checkpoint}."
+ )
+
+ # Save final weights
+ adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+ mx.save_safetensors(str(args.adapter_file), adapter_weights)
+ print(f"Saved final weights to {args.adapter_file}.")
\ No newline at end of file