mirror of https://github.com/ml-explore/mlx-examples.git
init grpo

This commit is contained in:
parent e2e5478da5
commit ec50a869b0
19 llms/mlx_lm/convert2.py Normal file
@@ -0,0 +1,19 @@
import pandas as pd
import os

# Define dataset directory
dataset_dir = "/Users/cshang/Desktop/test_grpo/data"

# Convert each Parquet file to JSONL
for file in os.listdir(dataset_dir):
    if file.endswith(".parquet"):
        parquet_path = os.path.join(dataset_dir, file)
        jsonl_path = os.path.join(dataset_dir, file.replace(".parquet", ".jsonl"))

        # Load Parquet file
        df = pd.read_parquet(parquet_path)

        # Convert to JSONL format
        df.to_json(jsonl_path, orient="records", lines=True)

        print(f"Converted {parquet_path} -> {jsonl_path}")
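Each converted `.jsonl` file is later consumed by the GRPO dataset loader, which by default expects `prompt` and `answer` fields in every record. A quick sanity check along those lines (a sketch only; the directory path mirrors the one hardcoded above):

```python
import json
import os

dataset_dir = "/Users/cshang/Desktop/test_grpo/data"  # same directory as above

# Peek at the first record of each converted file and confirm the expected keys.
for file in os.listdir(dataset_dir):
    if file.endswith(".jsonl"):
        with open(os.path.join(dataset_dir, file)) as fid:
            first = json.loads(fid.readline())
        assert "prompt" in first and "answer" in first, f"unexpected keys: {list(first)}"
        print(file, "ok:", list(first))
```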
@@ -15,6 +15,7 @@ import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
@@ -42,6 +43,7 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "training_mode": "normal",
    "fine_tune_type": "lora",
    "data": "data/",
    "seed": 0,
@@ -62,6 +64,15 @@ CONFIG_DEFAULTS = {
    "grad_checkpoint": False,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},

    # GRPO args
    "reference_model_path": None,
    "group_size": 4,
    "beta": 0.1,
    "epsilon": 1e-4,
    "max_completion_length": 512,
    "use_chat_template": False,
    "use_prompt": False,
}

@@ -94,6 +105,12 @@ def build_parser():
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--training-mode",
        type=str,
        choices=["normal", "grpo"],
        help="Training mode: normal or GRPO",
    )
    parser.add_argument(
        "--num-layers",
        type=int,
@@ -161,6 +178,44 @@ def build_parser():
        default=None,
    )
    parser.add_argument("--seed", type=int, help="The PRNG seed")

    # GRPO args
    parser.add_argument(
        "--group-size",
        type=int,
        help="Number of completions generated per prompt (the GRPO group size).",
        default=4,
    )
    parser.add_argument(
        "--max-completion-length",
        type=int,
        help="Maximum number of tokens to generate per completion.",
        default=512,
    )
    parser.add_argument(
        "--beta",
        type=float,
        help="KL penalty coefficient.",
        default=0.1,
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        help="Epsilon for numerical stability.",
        default=1e-4,
    )
    parser.add_argument(
        "--use-chat-template",
        action="store_true",
        help="If the model is a chat model, use its chat template.",
        default=None,
    )
    parser.add_argument(
        "--use-prompt",
        action="store_true",
        help="Whether to use the system prompt from the R1 paper.",
        default=None,
    )
    return parser

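A minimal sketch of how the new flags compose with the existing ones, assuming this parser is the one exposed by `mlx_lm.lora` (as the imports elsewhere in this diff suggest); every value below is a placeholder:

```python
# Illustrative only: parse a GRPO-style command line with the new arguments.
from mlx_lm.lora import build_parser

args = build_parser().parse_args([
    "--model", "mlx_model",
    "--train",
    "--data", "data/",
    "--training-mode", "grpo",
    "--group-size", "4",
    "--max-completion-length", "512",
    "--beta", "0.1",
    "--epsilon", "1e-4",
    "--use-chat-template",
])
print(args.training_mode, args.group_size, args.beta)
```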
@@ -220,32 +275,102 @@ def train_model(
        )
    )
    # Train model
    train(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        training_callback=training_callback,
    )
    if args.training_mode == "grpo":
        training_args = GRPOTrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            max_completion_length=args.max_completion_length,
            grad_checkpoint=args.grad_checkpoint,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            reference_model_path=args.reference_model_path
        )

        if args.reference_model_path:
            reference_model, _ = load(args.reference_model_path)
            reference_model = reference_model.freeze()
        else:
            reference_model, _ = load(args.model)

        train_grpo(
            model=model,
            ref_model=reference_model,
            tokenizer=tokenizer,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            args=training_args,
            training_callback=training_callback,
        )
    else:
        training_args = TrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            grad_checkpoint=args.grad_checkpoint
        )

        train(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            training_callback=training_callback,
        )


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
    model.eval()

    test_loss = evaluate(
        model=model,
        dataset=test_set,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.test_batches,
        max_seq_length=args.max_seq_length,
    )
    if args.training_mode == "grpo":
        if args.reference_model_path:
            reference_model, _ = load(args.reference_model_path)
        else:
            reference_model = model

    test_ppl = math.exp(test_loss)
        test_loss, _, test_rewards = evaluate_grpo(
            model=model,
            ref_model=reference_model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon
        )

    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards['total_rewards_mean']:.3f}, {test_rewards['total_rewards_std']:.3f}")
    else:
        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
        )

        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
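For reference, the GRPO branch of `train_model` above boils down to the following programmatic path. This is an illustrative sketch only, not part of the commit: the model path, toy data, optimizer, and hyperparameters are placeholders, and it assumes the modules added in this diff are importable.

```python
import mlx.optimizers as optim
from mlx_lm import load
from mlx_lm.tuner.datasets import GRPODataset
from mlx_lm.tuner.grpo_trainer import GRPOTrainingArgs, train_grpo

# Policy model plus a reference copy (train_model freezes the reference when
# --reference-model-path is given; here we freeze a reloaded copy of the base model).
model, tokenizer = load("mlx_model")
ref_model, _ = load("mlx_model")
ref_model.freeze()

# Toy prompt/answer pairs in the format GRPODataset expects.
data = [{"prompt": f"What is {i} + {i}?", "answer": str(2 * i)} for i in range(8)]
train_set, valid_set = GRPODataset(data, tokenizer), GRPODataset(data, tokenizer)

train_grpo(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    optimizer=optim.Adam(learning_rate=1e-5),
    train_dataset=train_set,
    val_dataset=valid_set,
    args=GRPOTrainingArgs(
        batch_size=2, iters=10, group_size=4, beta=0.1,
        epsilon=1e-4, max_completion_length=64,
    ),
)
```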
1 llms/mlx_lm/test_grpo Submodule
@@ -0,0 +1 @@
Subproject commit a74695c9280dd46208ea000f507f44bc8ddd9533
@@ -1,10 +1,59 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


class GRPODataset:
    """
    Dataset wrapper for GRPO training data.
    Each example should have a 'prompt' and 'answer' field.
    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
    """
    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        prompt_key: str = "prompt",
        answer_key: str = "answer",
        use_chat_template: bool = False,
        use_prompt: bool = False
    ):
        self._data = []
        for item in data:
            prompt_str = str(item[prompt_key])
            answer_str = str(item[answer_key])
            if use_chat_template:
                prompt_tokens = tokenizer.apply_chat_template(
                    [
                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                        {'role': 'user', 'content': prompt_str}
                    ],
                )
                answer_tokens = tokenizer.encode(answer_str)
            else:
                if use_prompt:
                    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str}. Assistant: """)
                else:
                    prompt_tokens = tokenizer.encode(prompt_str)
                answer_tokens = tokenizer.encode(answer_str)
            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
        return self._data[idx]

    def __len__(self) -> int:
        """Returns the number of examples in the dataset."""
        return len(self._data)


class Dataset:
    """
    Light-weight wrapper to hold a dataset.
@@ -82,6 +131,7 @@ class CompletionsDataset:


def create_dataset(
    args,
    data,
    tokenizer: PreTrainedTokenizer,
    prompt_feature: Optional[str] = None,
@@ -90,31 +140,44 @@ def create_dataset(
    prompt_feature = prompt_feature or "prompt"
    completion_feature = completion_feature or "completion"
    sample = data[0]
    if "messages" in sample:
        return ChatDataset(data, tokenizer)
    elif prompt_feature in sample and completion_feature in sample:
        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
    elif "text" in sample:
        return Dataset(data, tokenizer)

    if args.training_mode == "normal":
        if "messages" in sample:
            return ChatDataset(data, tokenizer)
        elif prompt_feature in sample and completion_feature in sample:
            return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
        elif "text" in sample:
            return Dataset(data, tokenizer)
        else:
            raise ValueError(
                "Unsupported data format, check the supported formats here:\n"
                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
            )
    else:
        raise ValueError(
            "Unsupported data format, check the supported formats here:\n"
            "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
        return GRPODataset(
            data=data,
            tokenizer=tokenizer,
            prompt_key="prompt",
            answer_key="answer",
            use_chat_template=args.use_chat_template,
            use_prompt=args.use_prompt
        )


def load_local_dataset(
    args,
    data_path: Path,
    tokenizer: PreTrainedTokenizer,
    prompt_feature: Optional[str] = None,
    completion_feature: Optional[str] = None,
):
    def load_subset(path):
        print(path)
        if not path.exists():
            return []
        with open(path, "r") as fid:
            data = [json.loads(l) for l in fid]
        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
        return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)

    names = ("train", "valid", "test")
    train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@ -122,6 +185,7 @@ def load_local_dataset(
|
||||
|
||||
|
||||
def load_hf_dataset(
|
||||
args,
|
||||
data_id: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
@ -137,7 +201,7 @@ def load_hf_dataset(
|
||||
train, valid, test = [
|
||||
(
|
||||
create_dataset(
|
||||
dataset[n], tokenizer, prompt_feature, completion_feature
|
||||
args, dataset[n], tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
if n in dataset.keys()
|
||||
else []
|
||||
@@ -202,12 +266,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
    completion_feature = getattr(args, "completion_feature", None)
    if data_path.exists():
        train, valid, test = load_local_dataset(
            data_path, tokenizer, prompt_feature, completion_feature
            args, data_path, tokenizer, prompt_feature, completion_feature
        )
    else:
        print(f"Loading Hugging Face dataset {args.data}.")
        train, valid, test = load_hf_dataset(
            args.data, tokenizer, prompt_feature, completion_feature
            args, args.data, tokenizer, prompt_feature, completion_feature
        )

    if args.train and len(train) == 0:
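To illustrate the record format `GRPODataset` produces, here is a small sketch with a stand-in tokenizer (a real `PreTrainedTokenizer` or mlx-lm `TokenizerWrapper` would be used in practice; the stub only exists to show the tuple layout):

```python
from mlx_lm.tuner.datasets import GRPODataset

class _StubTokenizer:
    def encode(self, text):
        return list(text.encode("utf-8"))  # placeholder token ids

data = [{"prompt": "What is 2 + 2?", "answer": "4"}]
ds = GRPODataset(data, _StubTokenizer())

prompt_tokens, answer_tokens, prompt_str, answer_str = ds[0]
print(len(ds), prompt_str, answer_str)  # 1 What is 2 + 2? 4
```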
690 llms/mlx_lm/tuner/grpo_trainer.py Normal file
@@ -0,0 +1,690 @@
# Copyright © 2024 Apple Inc.

import time
from dataclasses import dataclass, field
from pathlib import Path
import re

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten

from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

@dataclass
class GRPOTrainingArgs(TrainingArgs):
    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    beta: float = field(
        default=0.1, metadata={"help": "KL penalty coefficient."}
    )
    epsilon: float = field(
        default=1e-4, metadata={"help": "Epsilon for numerical stability."}
    )
    max_completion_length: int = field(
        default=512, metadata={"help": "Maximum number of tokens to generate per completion."}
    )
    reference_model_path: str = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        }
    )

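Since `GRPOTrainingArgs` extends `TrainingArgs`, the usual fields (batch size, iterations, adapter file, and so on) combine with the GRPO-specific ones. A small illustrative sketch, not part of the diff; all values are placeholders:

```python
from mlx_lm.tuner.grpo_trainer import GRPOTrainingArgs

args = GRPOTrainingArgs(
    batch_size=1,               # inherited from TrainingArgs
    iters=200,                  # inherited from TrainingArgs
    group_size=4,               # GRPO: completions per prompt
    beta=0.1,                   # GRPO: KL penalty coefficient
    epsilon=1e-4,               # GRPO: numerical stability term
    max_completion_length=512,  # GRPO: generation budget per completion
)
print(args.group_size, args.beta)
```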
def r1_extract_xml_answer(text: str) -> str:
    """Extracts the answer from an XML formatted text string."""
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    except:
        print("r1_extract_xml_answer returned empty string")
        return ""


def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Rewards 0.5 when the extracted <answer> is an integer string; always returns a list of floats."""
    if not completions:
        return [0.0] * len(prompts)
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]

def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Rewards 2.0 when the extracted <answer> matches the reference answer; always returns a list of floats."""
    if not completions or not answer:
        return [0.0] * len(prompts)
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Rewards 0.5 when the completion loosely follows the <think>...</think><answer>...</answer> format; always returns a list of floats."""
    if not completions:
        return [0.0] * len(prompts)
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
    return [0.5 if match else 0.0 for match in matches]


def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Rewards 0.5 when the completion follows the newline-delimited <think>/<answer> format exactly; always returns a list of floats."""
    if not completions:
        return [0.0] * len(prompts)
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
    return [0.5 if match else 0.0 for match in matches]


def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Gives partial credit (0.125 per expected XML tag) and penalizes trailing text after </answer>; always returns a list of floats."""
    if not completions:
        return [0.0] * len(prompts)

    scores = []
    for text in completions:
        if not text:
            scores.append(0.0)
            continue

        count = 0.0
        if text.count("<think>\n") == 1:
            count += 0.125
        if text.count("\n</think>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
        if text.count("\n</answer>\n") == 1:
            count += 0.125

        # Penalize extra text after </answer>
        end_text = text.split("\n</answer>\n")[-1]
        count -= len(end_text) * 0.001 if len(end_text) > 0 else 0

        scores.append(max(0.0, count))  # Ensure non-negative score

    return scores


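A quick illustration of how these reward functions score a well-formed completion (assuming this module is importable as `mlx_lm.tuner.grpo_trainer` once the commit is applied):

```python
from mlx_lm.tuner.grpo_trainer import (
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_soft_format_reward_func,
)

prompts = ["What is 2 + 2?"]
completions = ["<think>2 + 2 = 4</think> <answer>4</answer>"]
answer = ["4"]

print(r1_accuracy_reward_func(prompts, completions, answer))     # [2.0] -> answer matches
print(r1_int_reward_func(prompts, completions, answer))          # [0.5] -> answer is an integer
print(r1_soft_format_reward_func(prompts, completions, answer))  # [0.5] -> <think>/<answer> tags present
```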
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
    if len(prompt.shape) == 1:
        prompt = prompt[None, :]
    if prompt.shape[1] == 0:
        return None

    end_sequence = tokenizer.encode("</answer>")
    end_sequence_length = len(end_sequence)
    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
    output[:prompt.shape[1]] = prompt[0]
    current_length = prompt.shape[1]

    try:
        def sample(logits):
            if temperature > 0:
                logits /= temperature
            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]

        for _ in range(max_tokens):
            current_input = output[:current_length][None, :]
            logits = model(current_input)
            token_logits = logits[0, -1]
            next_token = sample(token_logits)
            token_value = next_token.item()
            output[current_length] = token_value
            current_length += 1

            if token_value == tokenizer.eos_token_id:
                break

            if current_length >= end_sequence_length:
                last_tokens = output[current_length - end_sequence_length:current_length].tolist()
                # print(f"Last tokens: {last_tokens}")
                # print(f"Decoded text: {tokenizer.decode(last_tokens)}")
                # print(f"Target sequence: {end_sequence}")
                if last_tokens == end_sequence:
                    break

        if current_length > prompt.shape[1]:
            return output[:current_length]

    except Exception as e:
        print(f"Generation error: {str(e)}")
        return None

    return None


def get_per_token_logps(model, inputs, lengths):
    logits = model(inputs).astype(mx.float16)
    logits = logits[:, :-1, :]
    targets = inputs[:, 1:]

    per_token_logps = []
    for i in range(logits.shape[0]):
        seq_len = int(lengths[i]) - 1

        seq_logits = logits[i, :seq_len]
        seq_targets = targets[i, :seq_len]

        log_probs = nn.log_softmax(seq_logits, axis=-1)

        token_log_probs = mx.take_along_axis(
            log_probs,
            seq_targets.reshape(seq_len, 1),
            axis=-1
        ).squeeze(-1)

        per_token_logps.append(token_log_probs)
    mx.eval(logits)
    return per_token_logps


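The gather in `get_per_token_logps` picks, for every position, the log-probability of the token that actually follows. A toy illustration of that indexing (not part of the diff):

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.random.normal((1, 4, 10))   # (batch, sequence, vocab)
targets = mx.array([[1, 3, 5]])         # next-token ids for the first 3 positions

log_probs = nn.log_softmax(logits[0, :3], axis=-1)
token_log_probs = mx.take_along_axis(
    log_probs, targets.reshape(3, 1), axis=-1
).squeeze(-1)
print(token_log_probs.shape)            # 3 values, one per target position
```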
def grpo_loss(
    model,
    tokenizer,
    batch,
    reward_funcs=None,
    beta=0.1,
    group_size=4,
    epsilon=1e-4,
    ref_model=None,
    max_tokens=64,
    temperature=1.0
):
    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
    batch_size = len(prompt_tokens)

    # Generate group_size completions for each prompt
    all_completions = []
    all_completion_texts = []

    for i in range(0, batch_size, batch_size):
        batch_prompts = prompt_tokens[i:i+batch_size]
        for prompt in batch_prompts:
            prompt_tensor = mx.array(prompt)
            for _ in range(group_size):
                try:
                    completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
                    if completion_ids is not None:
                        completion_text = tokenizer.decode(completion_ids.tolist())
                        all_completions.append(completion_ids)
                        all_completion_texts.append(completion_text)

                        # Clear completion tensors
                        mx.eval(completion_ids)
                        del completion_ids
                except Exception as e:
                    print(f"Generation error: {e}")
                    continue

    mx.metal.clear_cache()

    # Prepare inputs
    expanded_answers = []
    expanded_prompts = []
    for i in range(batch_size):
        expanded_answers.extend([answer_text[i]] * group_size)
        expanded_prompts.extend([prompt_text[i]] * group_size)

    max_length = max(ids.shape[0] for ids in all_completions)
    padded_completions = []
    attention_masks = []

    for completion_ids in all_completions:
        padding_length = max_length - completion_ids.shape[0]
        if padding_length > 0:
            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
            padded_ids = mx.concatenate([completion_ids, padding])
            mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
        else:
            padded_ids = completion_ids
            mask = mx.ones_like(completion_ids)
        padded_completions.append(padded_ids)
        attention_masks.append(mask)

    inputs = mx.stack(padded_completions)
    attention_mask = mx.stack(attention_masks)
    lengths = attention_mask.sum(axis=1)

    # Current policy probabilities
    token_log_probs = get_per_token_logps(model, inputs, lengths)

    mx.eval(token_log_probs)
    mx.metal.clear_cache()

    # Reference policy probabilities
    if ref_model is not None:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
    else:
        ref_token_log_probs = token_log_probs

    max_len = max(x.shape[0] for x in token_log_probs)
    padded_log_probs = []
    padded_ref_log_probs = []

    for i in range(len(token_log_probs)):
        seq_len = token_log_probs[i].shape[0]
        padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)

        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

    token_log_probs = mx.stack(padded_log_probs)
    ref_token_log_probs = mx.stack(padded_ref_log_probs)

    # Calculate rewards and advantages
    rewards = mx.zeros((len(all_completions),))
    for reward_func in reward_funcs:
        func_rewards = mx.array(reward_func(
            prompts=expanded_prompts,
            completions=all_completion_texts,
            answer=expanded_answers
        ))
        rewards += func_rewards

    if len(reward_funcs) > 1:
        rewards /= len(reward_funcs)

    # Reshape rewards and compute advantages following GRPO formula
    rewards_reshaped = rewards.reshape(batch_size, group_size)
    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

    # Compute KL divergence using Schulman's approximator
    kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)

    # Create mask for valid tokens
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # Compute policy ratio
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

    # Compute per-token loss following GRPO formula
    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

    # Average over tokens and sequences
    sequence_sums = per_token_loss.sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)

    loss = (sequence_sums / sequence_lengths).mean()

    # Calculate mean KL divergence for metrics
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()

    # Collect reward metrics
    reward_metrics = {}
    for i, reward_func in enumerate(reward_funcs):
        func_name = reward_func.__name__
        func_rewards = mx.array(reward_func(
            prompts=expanded_prompts,
            completions=all_completion_texts,
            answer=expanded_answers
        ))
        reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)

    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(rewards_reshaped),
        'grouped_rewards_std': mx.std(rewards_reshaped),
        'kl': mean_kl,
        **reward_metrics
    }
    mx.metal.clear_cache()

    return loss, sequence_lengths.sum(), metrics


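The advantage normalization and KL term above reduce to simple scalar arithmetic per completion and per token; a pure-Python illustration with toy numbers (not part of the diff):

```python
import math

# Group-relative advantages: rewards for group_size = 4 completions of one prompt.
group = [2.5, 0.5, 0.5, 0.0]
mean_r = sum(group) / len(group)
std_r = math.sqrt(sum((r - mean_r) ** 2 for r in group) / len(group))
epsilon = 1e-4
advantages = [(r - mean_r) / (std_r + epsilon) for r in group]
print(advantages)  # higher-reward completions get positive advantage

# Per-token KL estimator used above: exp(d) - 1 - d with d = logp - ref_logp (always >= 0).
logp, ref_logp = -1.2, -1.0
kl_hat = math.exp(logp - ref_logp) - 1 - (logp - ref_logp)
print(kl_hat)
```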
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """Yields (prompt_tokens, answer_tokens, prompt_str, answer_str) batches, sorted by length to limit memory use."""
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

    # Sort by length but use generator to avoid keeping full sorted list in memory
    def length_key(i):
        return len(dataset[i][0]) + len(dataset[i][1])

    idx = sorted(range(len(dataset)), key=length_key)

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    # Use generator for batch indices
    def batch_index_generator():
        for i in range(0, len(idx) - batch_size + 1, batch_size):
            yield idx[i : i + batch_size : step]

    while True:
        indices = (
            np.random.permutation(list(batch_index_generator())) if train
            else batch_index_generator()
        )

        for batch_idx in indices:
            current_batch = [dataset[j] for j in batch_idx]

            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )

            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break


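A small sketch of the batch format `iterate_grpo_batches` yields, using a toy in-memory dataset of `(prompt_tokens, answer_tokens, prompt_str, answer_str)` tuples (the tokenizer argument is only passed through, so `None` suffices here):

```python
from mlx_lm.tuner.grpo_trainer import iterate_grpo_batches

toy = [
    ([1, 2, 3], [4], "prompt A", "4"),
    ([1, 2], [5], "prompt B", "5"),
    ([1, 2, 3, 4], [6], "prompt C", "6"),
    ([1], [7], "prompt D", "7"),
]

batches = iterate_grpo_batches(toy, tokenizer=None, batch_size=2, max_seq_length=32)
prompts_tokens, answers_tokens, prompts_text, answers_text = next(batches)
print(prompts_text)  # shortest prompts are batched first
```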
def evaluate_grpo(
    model,
    ref_model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length,
    reward_funcs = None,
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches
):
    """
    Evaluate model using GRPO loss.
    Returns:
        tuple: (average loss, number of tokens, average metrics)
    """
    all_losses = 0
    ntokens = 0
    all_metrics = None  # Initialize metrics dictionary

    # Create iterator for batches
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    # Iterate through batches
    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        # Calculate loss for current batch
        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model
        )

        # Accumulate losses and tokens
        all_losses += losses * toks
        ntokens += toks

        # Accumulate metrics
        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        # Evaluate accumulated values
        mx.eval(all_losses, ntokens)

    # Aggregate across distributed workers
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

    # Calculate averages
    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, ntokens, avg_metrics


def train_grpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs = [
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml
    ],
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]

    def step(batch):

        # Forward and backward pass
        (loss, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
            max_tokens=args.max_completion_length,
        )

        # All reduce the gradients if running in distributed mode
        grad = average_gradients(grad)

        # Model update
        optimizer.update(model, grad)

        return loss, toks, metrics

    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = {
        'total_rewards_mean': 0,
        'total_rewards_std': 0,
        'grouped_rewards_mean': 0,
        'grouped_rewards_std': 0,
        'kl': 0
    }
    for reward_func in reward_funcs:
        func_name = reward_func.__name__
        accumulated_metrics[f'{func_name}_mean'] = 0
        accumulated_metrics[f'{func_name}_std'] = 0

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        # Report validation loss if needed, the first validation loss
        # is always measured before any training.
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss_fn=loss_fn,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                beta=args.beta,
                epsilon=args.epsilon,
                iterate_batches=iterate_batches,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.8f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )

                # Add reward function specific metrics
                for i, reward_func in enumerate(reward_funcs):
                    val_metrics_str += (
                        f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
                        f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {val_metrics_str}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            start = time.perf_counter()

        loss, toks, metrics = step(batch)
        losses += loss
        n_tokens += toks
        steps += 1

        for k, v in metrics.items():
            accumulated_metrics[k] += v

        mx.eval(state, losses, n_tokens)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * mx.distributed.init().size()
            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.8f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )

                # Add reward function specific metrics
                for i, reward_func in enumerate(reward_funcs):
                    func_name = reward_func.__name__
                    train_metrics_str += (
                        f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
                        f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })

            losses = 0
            n_tokens = 0
            steps = 0
            start = time.perf_counter()

        # Save adapter weights
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")
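After training, the adapter weights live in a safetensors file (commonly `adapters.safetensors`, the `TrainingArgs` default, unless overridden). A small sketch for inspecting or re-applying them; the path is a placeholder:

```python
import mlx.core as mx

weights = mx.load("adapters.safetensors")  # dict of parameter name -> array
print(len(weights), "tensors saved")

# To re-apply onto a model instance loaded elsewhere:
# model.load_weights(list(weights.items()), strict=False)
```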