mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 03:01:34 +08:00
init grpo
This commit is contained in:
parent
e2e5478da5
commit
ec50a869b0
19
llms/mlx_lm/convert2.py
Normal file
19
llms/mlx_lm/convert2.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Define dataset directory
|
||||||
|
dataset_dir = "/Users/cshang/Desktop/test_grpo/data"
|
||||||
|
|
||||||
|
# Convert each Parquet file to JSONL
|
||||||
|
for file in os.listdir(dataset_dir):
|
||||||
|
if file.endswith(".parquet"):
|
||||||
|
parquet_path = os.path.join(dataset_dir, file)
|
||||||
|
jsonl_path = os.path.join(dataset_dir, file.replace(".parquet", ".jsonl"))
|
||||||
|
|
||||||
|
# Load Parquet file
|
||||||
|
df = pd.read_parquet(parquet_path)
|
||||||
|
|
||||||
|
# Convert to JSONL format
|
||||||
|
df.to_json(jsonl_path, orient="records", lines=True)
|
||||||
|
|
||||||
|
print(f"Converted {parquet_path} -> {jsonl_path}")
|
@ -15,6 +15,7 @@ import yaml
|
|||||||
from .tokenizer_utils import TokenizerWrapper
|
from .tokenizer_utils import TokenizerWrapper
|
||||||
from .tuner.datasets import load_dataset
|
from .tuner.datasets import load_dataset
|
||||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||||
|
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
|
||||||
from .tuner.utils import (
|
from .tuner.utils import (
|
||||||
build_schedule,
|
build_schedule,
|
||||||
linear_to_lora_layers,
|
linear_to_lora_layers,
|
||||||
@ -42,6 +43,7 @@ yaml_loader.add_implicit_resolver(
|
|||||||
CONFIG_DEFAULTS = {
|
CONFIG_DEFAULTS = {
|
||||||
"model": "mlx_model",
|
"model": "mlx_model",
|
||||||
"train": False,
|
"train": False,
|
||||||
|
"training_mode": "normal",
|
||||||
"fine_tune_type": "lora",
|
"fine_tune_type": "lora",
|
||||||
"data": "data/",
|
"data": "data/",
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
@ -62,6 +64,15 @@ CONFIG_DEFAULTS = {
|
|||||||
"grad_checkpoint": False,
|
"grad_checkpoint": False,
|
||||||
"lr_schedule": None,
|
"lr_schedule": None,
|
||||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||||
|
|
||||||
|
# GRPO args
|
||||||
|
"reference_model_path": None,
|
||||||
|
"group_size": 4,
|
||||||
|
"beta": 0.1,
|
||||||
|
"epsilon": 1e-4,
|
||||||
|
"max_completion_length": 512,
|
||||||
|
"use_chat_template": False,
|
||||||
|
"use_prompt": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -94,6 +105,12 @@ def build_parser():
|
|||||||
choices=["lora", "dora", "full"],
|
choices=["lora", "dora", "full"],
|
||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--training-mode",
|
||||||
|
type=str,
|
||||||
|
choices=["normal", "grpo"],
|
||||||
|
help="Training mode: normal or GRPO",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-layers",
|
"--num-layers",
|
||||||
type=int,
|
type=int,
|
||||||
@ -161,6 +178,44 @@ def build_parser():
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||||
|
|
||||||
|
# GRPO args
|
||||||
|
parser.add_argument(
|
||||||
|
"--group-size",
|
||||||
|
type=int,
|
||||||
|
help="Number of generations.",
|
||||||
|
default=4,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-completion-length",
|
||||||
|
type=int,
|
||||||
|
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
|
||||||
|
default=512,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--beta",
|
||||||
|
type=float,
|
||||||
|
help="KL penalty coefficient.",
|
||||||
|
default=0.1,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epsilon",
|
||||||
|
type=float,
|
||||||
|
help="The Epsilon for numerical stability.",
|
||||||
|
default=1e-4,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-chat-template",
|
||||||
|
action="store_true",
|
||||||
|
help="If the model is a Chat model, use the Chat template.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Rather to use the prompt from the R1 paper.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -220,32 +275,102 @@ def train_model(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Train model
|
# Train model
|
||||||
train(
|
if args.training_mode == "grpo":
|
||||||
model=model,
|
training_args = GRPOTrainingArgs(
|
||||||
tokenizer=tokenizer,
|
batch_size=args.batch_size,
|
||||||
args=training_args,
|
iters=args.iters,
|
||||||
optimizer=opt,
|
val_batches=args.val_batches,
|
||||||
train_dataset=train_set,
|
steps_per_report=args.steps_per_report,
|
||||||
val_dataset=valid_set,
|
steps_per_eval=args.steps_per_eval,
|
||||||
training_callback=training_callback,
|
steps_per_save=args.save_every,
|
||||||
)
|
adapter_file=adapter_file,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
max_completion_length=args.max_completion_length,
|
||||||
|
grad_checkpoint=args.grad_checkpoint,
|
||||||
|
beta=args.beta,
|
||||||
|
group_size=args.group_size,
|
||||||
|
epsilon=args.epsilon,
|
||||||
|
reference_model_path=args.reference_model_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.reference_model_path:
|
||||||
|
reference_model, _ = load(args.reference_model_path)
|
||||||
|
reference_model = reference_model.freeze()
|
||||||
|
else:
|
||||||
|
reference_model, _ = load(args.model)
|
||||||
|
|
||||||
|
train_grpo(
|
||||||
|
model=model,
|
||||||
|
ref_model=reference_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
optimizer=opt,
|
||||||
|
train_dataset=train_set,
|
||||||
|
val_dataset=valid_set,
|
||||||
|
args=training_args,
|
||||||
|
training_callback=training_callback,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
training_args = TrainingArgs(
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
iters=args.iters,
|
||||||
|
val_batches=args.val_batches,
|
||||||
|
steps_per_report=args.steps_per_report,
|
||||||
|
steps_per_eval=args.steps_per_eval,
|
||||||
|
steps_per_save=args.save_every,
|
||||||
|
adapter_file=adapter_file,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
grad_checkpoint=args.grad_checkpoint
|
||||||
|
)
|
||||||
|
|
||||||
|
train(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
args=training_args,
|
||||||
|
optimizer=opt,
|
||||||
|
train_dataset=train_set,
|
||||||
|
val_dataset=valid_set,
|
||||||
|
training_callback=training_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
|
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
test_loss = evaluate(
|
if args.training_mode == "grpo":
|
||||||
model=model,
|
if args.reference_model_path:
|
||||||
dataset=test_set,
|
reference_model, _ = load(args.reference_model_path)
|
||||||
tokenizer=tokenizer,
|
else:
|
||||||
batch_size=args.batch_size,
|
reference_model = model
|
||||||
num_batches=args.test_batches,
|
|
||||||
max_seq_length=args.max_seq_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_ppl = math.exp(test_loss)
|
test_loss, _, test_rewards = evaluate_grpo(
|
||||||
|
model=model,
|
||||||
|
ref_model=reference_model,
|
||||||
|
dataset=test_set,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_batches=args.test_batches,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
beta=args.beta,
|
||||||
|
group_size=args.group_size,
|
||||||
|
epsilon=args.epsilon
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
test_ppl = math.exp(test_loss)
|
||||||
|
|
||||||
|
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
|
||||||
|
else:
|
||||||
|
test_loss = evaluate(
|
||||||
|
model=model,
|
||||||
|
dataset=test_set,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_batches=args.test_batches,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_ppl = math.exp(test_loss)
|
||||||
|
|
||||||
|
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||||
|
|
||||||
|
|
||||||
def run(args, training_callback: TrainingCallback = None):
|
def run(args, training_callback: TrainingCallback = None):
|
||||||
|
1
llms/mlx_lm/test_grpo
Submodule
1
llms/mlx_lm/test_grpo
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit a74695c9280dd46208ea000f507f44bc8ddd9533
|
@ -1,10 +1,59 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GRPODataset:
|
||||||
|
"""
|
||||||
|
Dataset wrapper for GRPO training data.
|
||||||
|
Each example should have a 'prompt' and 'answer' field.
|
||||||
|
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data: List[Dict[str, str]],
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
prompt_key: str = "prompt",
|
||||||
|
answer_key: str = "answer",
|
||||||
|
use_chat_template: bool = False,
|
||||||
|
use_prompt: bool = False
|
||||||
|
):
|
||||||
|
self._data = []
|
||||||
|
for item in data:
|
||||||
|
prompt_str = str(item[prompt_key])
|
||||||
|
answer_str = str(item[answer_key])
|
||||||
|
if use_chat_template:
|
||||||
|
prompt_tokens = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||||
|
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
|
||||||
|
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
|
||||||
|
{'role': 'user', 'content': prompt_str}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
answer_tokens = tokenizer.encode(answer_str)
|
||||||
|
else:
|
||||||
|
if use_prompt:
|
||||||
|
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||||
|
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
|
||||||
|
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
|
||||||
|
User: {prompt_str}. Assistant: """)
|
||||||
|
else:
|
||||||
|
prompt_tokens = tokenizer.encode(prompt_str)
|
||||||
|
answer_tokens = tokenizer.encode(answer_str)
|
||||||
|
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
|
||||||
|
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
|
||||||
|
return self._data[idx]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Returns the number of examples in the dataset."""
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
"""
|
"""
|
||||||
Light-weight wrapper to hold a dataset.
|
Light-weight wrapper to hold a dataset.
|
||||||
@ -82,6 +131,7 @@ class CompletionsDataset:
|
|||||||
|
|
||||||
|
|
||||||
def create_dataset(
|
def create_dataset(
|
||||||
|
args,
|
||||||
data,
|
data,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
prompt_feature: Optional[str] = None,
|
||||||
@ -90,31 +140,44 @@ def create_dataset(
|
|||||||
prompt_feature = prompt_feature or "prompt"
|
prompt_feature = prompt_feature or "prompt"
|
||||||
completion_feature = completion_feature or "completion"
|
completion_feature = completion_feature or "completion"
|
||||||
sample = data[0]
|
sample = data[0]
|
||||||
if "messages" in sample:
|
|
||||||
return ChatDataset(data, tokenizer)
|
if args.training_mode == "normal":
|
||||||
elif prompt_feature in sample and completion_feature in sample:
|
if "messages" in sample:
|
||||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
return ChatDataset(data, tokenizer)
|
||||||
elif "text" in sample:
|
elif prompt_feature in sample and completion_feature in sample:
|
||||||
return Dataset(data, tokenizer)
|
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
||||||
|
elif "text" in sample:
|
||||||
|
return Dataset(data, tokenizer)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported data format, check the supported formats here:\n"
|
||||||
|
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
return GRPODataset(
|
||||||
"Unsupported data format, check the supported formats here:\n"
|
data=data,
|
||||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
tokenizer=tokenizer,
|
||||||
|
prompt_key="prompt",
|
||||||
|
answer_key="answer",
|
||||||
|
use_chat_template=args.use_chat_template,
|
||||||
|
use_prompt=args.use_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_local_dataset(
|
def load_local_dataset(
|
||||||
|
args,
|
||||||
data_path: Path,
|
data_path: Path,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
prompt_feature: Optional[str] = None,
|
||||||
completion_feature: Optional[str] = None,
|
completion_feature: Optional[str] = None,
|
||||||
):
|
):
|
||||||
def load_subset(path):
|
def load_subset(path):
|
||||||
|
print(path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return []
|
return []
|
||||||
with open(path, "r") as fid:
|
with open(path, "r") as fid:
|
||||||
data = [json.loads(l) for l in fid]
|
data = [json.loads(l) for l in fid]
|
||||||
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
|
return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
|
||||||
|
|
||||||
names = ("train", "valid", "test")
|
names = ("train", "valid", "test")
|
||||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||||
@ -122,6 +185,7 @@ def load_local_dataset(
|
|||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(
|
def load_hf_dataset(
|
||||||
|
args,
|
||||||
data_id: str,
|
data_id: str,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
prompt_feature: Optional[str] = None,
|
prompt_feature: Optional[str] = None,
|
||||||
@ -137,7 +201,7 @@ def load_hf_dataset(
|
|||||||
train, valid, test = [
|
train, valid, test = [
|
||||||
(
|
(
|
||||||
create_dataset(
|
create_dataset(
|
||||||
dataset[n], tokenizer, prompt_feature, completion_feature
|
args, dataset[n], tokenizer, prompt_feature, completion_feature
|
||||||
)
|
)
|
||||||
if n in dataset.keys()
|
if n in dataset.keys()
|
||||||
else []
|
else []
|
||||||
@ -202,12 +266,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
|||||||
completion_feature = getattr(args, "completion_feature", None)
|
completion_feature = getattr(args, "completion_feature", None)
|
||||||
if data_path.exists():
|
if data_path.exists():
|
||||||
train, valid, test = load_local_dataset(
|
train, valid, test = load_local_dataset(
|
||||||
data_path, tokenizer, prompt_feature, completion_feature
|
args, data_path, tokenizer, prompt_feature, completion_feature
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Loading Hugging Face dataset {args.data}.")
|
print(f"Loading Hugging Face dataset {args.data}.")
|
||||||
train, valid, test = load_hf_dataset(
|
train, valid, test = load_hf_dataset(
|
||||||
args.data, tokenizer, prompt_feature, completion_feature
|
args, args.data, tokenizer, prompt_feature, completion_feature
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.train and len(train) == 0:
|
if args.train and len(train) == 0:
|
||||||
|
690
llms/mlx_lm/tuner/grpo_trainer.py
Normal file
690
llms/mlx_lm/tuner/grpo_trainer.py
Normal file
@ -0,0 +1,690 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GRPOTrainingArgs(TrainingArgs):
|
||||||
|
group_size: int = field(
|
||||||
|
default=4,
|
||||||
|
metadata={"help": "Number of responses per prompt."},
|
||||||
|
)
|
||||||
|
beta: float = field(
|
||||||
|
default=0.1, metadata={"help": "KL penalty coefficient."}
|
||||||
|
)
|
||||||
|
epsilon: float = field(
|
||||||
|
default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
|
||||||
|
)
|
||||||
|
max_completion_length: int = field(
|
||||||
|
default=512, metadata={"help": "Number of Generations."}
|
||||||
|
)
|
||||||
|
reference_model_path: str = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Path to reference model weights. If None, uses the same model."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def r1_extract_xml_answer(text: str) -> str:
|
||||||
|
"""Extracts the answer from an XML formatted text string."""
|
||||||
|
try:
|
||||||
|
answer = text.split("<answer>")[-1]
|
||||||
|
answer = answer.split("</answer>")[0]
|
||||||
|
return answer.strip()
|
||||||
|
except:
|
||||||
|
print("r1_extract_xml_answer returned empty string")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||||
|
"""Ensures we always return a list of floats."""
|
||||||
|
if not completions:
|
||||||
|
return [0.0] * len(prompts)
|
||||||
|
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||||
|
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||||
|
|
||||||
|
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||||
|
"""Ensures we always return a list of floats."""
|
||||||
|
if not completions or not answer:
|
||||||
|
return [0.0] * len(prompts)
|
||||||
|
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||||
|
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||||
|
|
||||||
|
|
||||||
|
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||||
|
"""Ensures we always return a list of floats."""
|
||||||
|
if not completions:
|
||||||
|
return [0.0] * len(prompts)
|
||||||
|
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
||||||
|
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||||
|
return [0.5 if match else 0.0 for match in matches]
|
||||||
|
|
||||||
|
|
||||||
|
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||||
|
"""Ensures we always return a list of floats."""
|
||||||
|
if not completions:
|
||||||
|
return [0.0] * len(prompts)
|
||||||
|
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
|
||||||
|
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||||
|
return [0.5 if match else 0.0 for match in matches]
|
||||||
|
|
||||||
|
|
||||||
|
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||||
|
"""Ensures we always return a list of floats."""
|
||||||
|
if not completions:
|
||||||
|
return [0.0] * len(prompts)
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
for text in completions:
|
||||||
|
if not text:
|
||||||
|
scores.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
count = 0.0
|
||||||
|
if text.count("<think>\n") == 1:
|
||||||
|
count += 0.125
|
||||||
|
if text.count("\n</think>\n") == 1:
|
||||||
|
count += 0.125
|
||||||
|
if text.count("\n<answer>\n") == 1:
|
||||||
|
count += 0.125
|
||||||
|
if text.count("\n</answer>\n") == 1:
|
||||||
|
count += 0.125
|
||||||
|
|
||||||
|
# Penalize extra text after </answer>
|
||||||
|
end_text = text.split("\n</answer>\n")[-1]
|
||||||
|
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||||
|
|
||||||
|
scores.append(max(0.0, count)) # Ensure non-negative score
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
|
||||||
|
if len(prompt.shape) == 1:
|
||||||
|
prompt = prompt[None, :]
|
||||||
|
if prompt.shape[1] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
end_sequence = tokenizer.encode("</answer>")
|
||||||
|
end_sequence_length = len(end_sequence)
|
||||||
|
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
|
||||||
|
output[:prompt.shape[1]] = prompt[0]
|
||||||
|
current_length = prompt.shape[1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
def sample(logits):
|
||||||
|
if temperature > 0:
|
||||||
|
logits /= temperature
|
||||||
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
|
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
|
||||||
|
|
||||||
|
for _ in range(max_tokens):
|
||||||
|
current_input = output[:current_length][None, :]
|
||||||
|
logits = model(current_input)
|
||||||
|
token_logits = logits[0, -1]
|
||||||
|
next_token = sample(token_logits)
|
||||||
|
token_value = next_token.item()
|
||||||
|
output[current_length] = token_value
|
||||||
|
current_length += 1
|
||||||
|
|
||||||
|
if token_value == tokenizer.eos_token_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
if current_length >= end_sequence_length:
|
||||||
|
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
|
||||||
|
# print(f"Last tokens: {last_tokens}")
|
||||||
|
# print(f"Decoded text: {tokenizer.decode(last_tokens)}")
|
||||||
|
# print(f"Target sequence: {end_sequence}")
|
||||||
|
if last_tokens == end_sequence:
|
||||||
|
break
|
||||||
|
|
||||||
|
if current_length > prompt.shape[1]:
|
||||||
|
return output[:current_length]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Generation error: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_per_token_logps(model, inputs, lengths):
|
||||||
|
logits = model(inputs).astype(mx.float16)
|
||||||
|
logits = logits[:, :-1, :]
|
||||||
|
targets = inputs[:, 1:]
|
||||||
|
|
||||||
|
per_token_logps = []
|
||||||
|
for i in range(logits.shape[0]):
|
||||||
|
seq_len = int(lengths[i]) - 1
|
||||||
|
|
||||||
|
seq_logits = logits[i, :seq_len]
|
||||||
|
seq_targets = targets[i, :seq_len]
|
||||||
|
|
||||||
|
log_probs = nn.log_softmax(seq_logits, axis=-1)
|
||||||
|
|
||||||
|
token_log_probs = mx.take_along_axis(
|
||||||
|
log_probs,
|
||||||
|
seq_targets.reshape(seq_len, 1),
|
||||||
|
axis=-1
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
per_token_logps.append(token_log_probs)
|
||||||
|
mx.eval(logits)
|
||||||
|
return per_token_logps
|
||||||
|
|
||||||
|
|
||||||
|
def grpo_loss(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
batch,
|
||||||
|
reward_funcs=None,
|
||||||
|
beta=0.1,
|
||||||
|
group_size=4,
|
||||||
|
epsilon=1e-4,
|
||||||
|
ref_model=None,
|
||||||
|
max_tokens=64,
|
||||||
|
temperature=1.0
|
||||||
|
):
|
||||||
|
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
|
||||||
|
batch_size = len(prompt_tokens)
|
||||||
|
|
||||||
|
# Generation logic remains the same
|
||||||
|
all_completions = []
|
||||||
|
all_completion_texts = []
|
||||||
|
|
||||||
|
for i in range(0, batch_size, batch_size):
|
||||||
|
batch_prompts = prompt_tokens[i:i+batch_size]
|
||||||
|
for prompt in batch_prompts:
|
||||||
|
prompt_tensor = mx.array(prompt)
|
||||||
|
for _ in range(group_size):
|
||||||
|
try:
|
||||||
|
completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
|
||||||
|
if completion_ids is not None:
|
||||||
|
completion_text = tokenizer.decode(completion_ids.tolist())
|
||||||
|
all_completions.append(completion_ids)
|
||||||
|
all_completion_texts.append(completion_text)
|
||||||
|
|
||||||
|
# Clear completion tensors
|
||||||
|
mx.eval(completion_ids)
|
||||||
|
del completion_ids
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Generation error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
expanded_answers = []
|
||||||
|
expanded_prompts = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
expanded_answers.extend([answer_text[i]] * group_size)
|
||||||
|
expanded_prompts.extend([prompt_text[i]] * group_size)
|
||||||
|
|
||||||
|
max_length = max(ids.shape[0] for ids in all_completions)
|
||||||
|
padded_completions = []
|
||||||
|
attention_masks = []
|
||||||
|
|
||||||
|
for completion_ids in all_completions:
|
||||||
|
padding_length = max_length - completion_ids.shape[0]
|
||||||
|
if padding_length > 0:
|
||||||
|
padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
|
||||||
|
padded_ids = mx.concatenate([completion_ids, padding])
|
||||||
|
mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
|
||||||
|
else:
|
||||||
|
padded_ids = completion_ids
|
||||||
|
mask = mx.ones_like(completion_ids)
|
||||||
|
padded_completions.append(padded_ids)
|
||||||
|
attention_masks.append(mask)
|
||||||
|
|
||||||
|
inputs = mx.stack(padded_completions)
|
||||||
|
attention_mask = mx.stack(attention_masks)
|
||||||
|
lengths = attention_mask.sum(axis=1)
|
||||||
|
|
||||||
|
# Current policy probabilities
|
||||||
|
token_log_probs = get_per_token_logps(model, inputs, lengths)
|
||||||
|
|
||||||
|
mx.eval(token_log_probs)
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
|
# Reference policy probabilities
|
||||||
|
if ref_model is not None:
|
||||||
|
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
|
||||||
|
else:
|
||||||
|
ref_token_log_probs = token_log_probs
|
||||||
|
|
||||||
|
max_len = max(x.shape[0] for x in token_log_probs)
|
||||||
|
padded_log_probs = []
|
||||||
|
padded_ref_log_probs = []
|
||||||
|
|
||||||
|
for i in range(len(token_log_probs)):
|
||||||
|
seq_len = token_log_probs[i].shape[0]
|
||||||
|
padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
|
||||||
|
|
||||||
|
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
|
||||||
|
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
|
||||||
|
|
||||||
|
token_log_probs = mx.stack(padded_log_probs)
|
||||||
|
ref_token_log_probs = mx.stack(padded_ref_log_probs)
|
||||||
|
|
||||||
|
# Calculate rewards and advantages
|
||||||
|
rewards = mx.zeros((len(all_completions),))
|
||||||
|
for reward_func in reward_funcs:
|
||||||
|
func_rewards = mx.array(reward_func(
|
||||||
|
prompts=expanded_prompts,
|
||||||
|
completions=all_completion_texts,
|
||||||
|
answer=expanded_answers
|
||||||
|
))
|
||||||
|
rewards += func_rewards
|
||||||
|
|
||||||
|
if len(reward_funcs) > 1:
|
||||||
|
rewards /= len(reward_funcs)
|
||||||
|
|
||||||
|
# Reshape rewards and compute advantages following GRPO formula
|
||||||
|
rewards_reshaped = rewards.reshape(batch_size, group_size)
|
||||||
|
mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
|
||||||
|
std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
|
||||||
|
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
|
||||||
|
|
||||||
|
# Compute KL divergence using Schulman's approximator
|
||||||
|
kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
|
||||||
|
|
||||||
|
# Create mask for valid tokens
|
||||||
|
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
|
||||||
|
|
||||||
|
# Compute policy ratio
|
||||||
|
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
|
||||||
|
|
||||||
|
# Compute per-token loss following GRPO formula
|
||||||
|
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
|
||||||
|
|
||||||
|
# Average over tokens and sequences
|
||||||
|
sequence_sums = per_token_loss.sum(axis=1)
|
||||||
|
sequence_lengths = length_mask.sum(axis=1)
|
||||||
|
|
||||||
|
loss = (sequence_sums / sequence_lengths).mean()
|
||||||
|
|
||||||
|
# Calculate mean KL divergence for metrics
|
||||||
|
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
|
||||||
|
|
||||||
|
# Collect reward metrics
|
||||||
|
reward_metrics = {}
|
||||||
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
|
func_name = reward_func.__name__
|
||||||
|
func_rewards = mx.array(reward_func(
|
||||||
|
prompts=expanded_prompts,
|
||||||
|
completions=all_completion_texts,
|
||||||
|
answer=expanded_answers
|
||||||
|
))
|
||||||
|
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
|
||||||
|
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
'total_rewards_mean': mx.mean(rewards),
|
||||||
|
'total_rewards_std': mx.std(rewards),
|
||||||
|
'grouped_rewards_mean': mx.mean(rewards_reshaped),
|
||||||
|
'grouped_rewards_std': mx.std(rewards_reshaped),
|
||||||
|
'kl': mean_kl,
|
||||||
|
**reward_metrics
|
||||||
|
}
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
|
return loss, sequence_lengths.sum(), metrics
|
||||||
|
|
||||||
|
|
||||||
|
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||||
|
"""Memory-optimized version of iterate_grpo_batches"""
|
||||||
|
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
|
||||||
|
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
|
||||||
|
|
||||||
|
# Sort by length but use generator to avoid keeping full sorted list in memory
|
||||||
|
def length_key(i):
|
||||||
|
return len(dataset[i][0]) + len(dataset[i][1])
|
||||||
|
|
||||||
|
idx = sorted(range(len(dataset)), key=length_key)
|
||||||
|
|
||||||
|
if len(dataset) < batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset must have at least batch_size={batch_size} "
|
||||||
|
f"examples but only has {len(dataset)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
step = mx.distributed.init().size()
|
||||||
|
if batch_size % step != 0:
|
||||||
|
raise ValueError("The batch size must be divisible by the number of workers")
|
||||||
|
|
||||||
|
# Use generator for batch indices
|
||||||
|
def batch_index_generator():
|
||||||
|
for i in range(0, len(idx) - batch_size + 1, batch_size):
|
||||||
|
yield idx[i : i + batch_size : step]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
indices = (
|
||||||
|
np.random.permutation(list(batch_index_generator())) if train
|
||||||
|
else batch_index_generator()
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_idx in indices:
|
||||||
|
current_batch = [dataset[j] for j in batch_idx]
|
||||||
|
|
||||||
|
prompts_tokens = [item[0] for item in current_batch]
|
||||||
|
answers_tokens = [item[1] for item in current_batch]
|
||||||
|
prompts_text = [item[2] for item in current_batch]
|
||||||
|
answers_text = [item[3] for item in current_batch]
|
||||||
|
|
||||||
|
if any(len(p) > max_seq_length for p in prompts_tokens):
|
||||||
|
print(
|
||||||
|
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
|
||||||
|
"Long prompts will be truncated."
|
||||||
|
)
|
||||||
|
|
||||||
|
yield prompts_tokens, answers_tokens, prompts_text, answers_text
|
||||||
|
|
||||||
|
if not train:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_grpo(
|
||||||
|
model,
|
||||||
|
ref_model,
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
batch_size,
|
||||||
|
num_batches,
|
||||||
|
beta: float,
|
||||||
|
epsilon: float,
|
||||||
|
group_size: int,
|
||||||
|
max_seq_length,
|
||||||
|
reward_funcs = None,
|
||||||
|
loss_fn: callable = grpo_loss,
|
||||||
|
iterate_batches: callable = iterate_grpo_batches
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Evaluate model using GRPO loss.
|
||||||
|
Returns:
|
||||||
|
tuple: (average loss, number of tokens, average metrics)
|
||||||
|
"""
|
||||||
|
all_losses = 0
|
||||||
|
ntokens = 0
|
||||||
|
all_metrics = None # Initialize metrics dictionary
|
||||||
|
|
||||||
|
# Create iterator for batches
|
||||||
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
|
# Iterate through batches
|
||||||
|
for _, batch in zip(
|
||||||
|
index_iterator,
|
||||||
|
iterate_batches(
|
||||||
|
dataset=dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_seq_length=max_seq_length,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# Calculate loss for current batch
|
||||||
|
losses, toks, metrics = loss_fn(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch=batch,
|
||||||
|
reward_funcs=reward_funcs,
|
||||||
|
beta=beta,
|
||||||
|
group_size=group_size,
|
||||||
|
epsilon=epsilon,
|
||||||
|
ref_model=ref_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Accumulate losses and tokens
|
||||||
|
all_losses += losses * toks
|
||||||
|
ntokens += toks
|
||||||
|
|
||||||
|
# Accumulate metrics
|
||||||
|
if all_metrics is None:
|
||||||
|
all_metrics = {k: v * toks for k, v in metrics.items()}
|
||||||
|
else:
|
||||||
|
for k, v in metrics.items():
|
||||||
|
all_metrics[k] += v * toks
|
||||||
|
|
||||||
|
# Evaluate accumulated values
|
||||||
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
|
# Aggregate across distributed workers
|
||||||
|
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||||
|
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||||
|
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
|
||||||
|
avg_loss = (all_losses / ntokens).item()
|
||||||
|
|
||||||
|
return avg_loss, ntokens, avg_metrics
|
||||||
|
|
||||||
|
|
||||||
|
def train_grpo(
|
||||||
|
model,
|
||||||
|
ref_model,
|
||||||
|
tokenizer,
|
||||||
|
optimizer,
|
||||||
|
train_dataset,
|
||||||
|
val_dataset,
|
||||||
|
reward_funcs = [
|
||||||
|
r1_accuracy_reward_func,
|
||||||
|
r1_int_reward_func,
|
||||||
|
r1_strict_format_reward_func,
|
||||||
|
r1_soft_format_reward_func,
|
||||||
|
r1_count_xml
|
||||||
|
],
|
||||||
|
args: GRPOTrainingArgs = GRPOTrainingArgs(),
|
||||||
|
loss_fn: callable = grpo_loss,
|
||||||
|
iterate_batches: callable = iterate_grpo_batches,
|
||||||
|
training_callback: TrainingCallback = None,
|
||||||
|
):
|
||||||
|
print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
|
||||||
|
world = mx.distributed.init()
|
||||||
|
world_size = world.size()
|
||||||
|
rank = world.rank()
|
||||||
|
if world_size > 1:
|
||||||
|
print(f"Node {rank} of {world_size}")
|
||||||
|
|
||||||
|
if args.grad_checkpoint:
|
||||||
|
grad_checkpoint(model.layers[0])
|
||||||
|
|
||||||
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
|
def step(batch):
|
||||||
|
|
||||||
|
# Forward and backward pass
|
||||||
|
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||||
|
model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch=batch,
|
||||||
|
reward_funcs=reward_funcs,
|
||||||
|
beta=args.beta,
|
||||||
|
group_size=args.group_size,
|
||||||
|
epsilon=args.epsilon,
|
||||||
|
ref_model=ref_model,
|
||||||
|
max_tokens=args.max_completion_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All reduce the gradients if running in distributed mode
|
||||||
|
grad = average_gradients(grad)
|
||||||
|
|
||||||
|
# Model update
|
||||||
|
optimizer.update(model, grad)
|
||||||
|
|
||||||
|
return loss, toks, metrics
|
||||||
|
|
||||||
|
loss_value_and_grad = nn.value_and_grad(model, loss_fn)
|
||||||
|
|
||||||
|
losses = 0
|
||||||
|
n_tokens = 0
|
||||||
|
steps = 0
|
||||||
|
trained_tokens = 0
|
||||||
|
accumulated_metrics = {
|
||||||
|
'total_rewards_mean': 0,
|
||||||
|
'total_rewards_std': 0,
|
||||||
|
'grouped_rewards_mean': 0,
|
||||||
|
'grouped_rewards_std': 0,
|
||||||
|
'kl': 0
|
||||||
|
}
|
||||||
|
for reward_func in reward_funcs:
|
||||||
|
func_name = reward_func.__name__
|
||||||
|
accumulated_metrics[f'{func_name}_mean'] = 0
|
||||||
|
accumulated_metrics[f'{func_name}_std'] = 0
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
for it, batch in zip(
|
||||||
|
range(1, args.iters + 1),
|
||||||
|
iterate_batches(
|
||||||
|
dataset=train_dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
train=True,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# Report validation loss if needed, the first validation loss
|
||||||
|
# is always measured before any training.
|
||||||
|
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||||
|
stop = time.perf_counter()
|
||||||
|
val_loss, val_ntokens, val_metrics = evaluate_grpo(
|
||||||
|
model=model,
|
||||||
|
dataset=val_dataset,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
ref_model=ref_model,
|
||||||
|
reward_funcs=reward_funcs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
group_size=args.group_size,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_batches=args.val_batches,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
beta=args.beta,
|
||||||
|
epsilon=args.epsilon,
|
||||||
|
iterate_batches=iterate_batches,
|
||||||
|
)
|
||||||
|
val_time = time.perf_counter() - stop
|
||||||
|
if rank == 0:
|
||||||
|
val_metrics_str = (
|
||||||
|
f"Val loss {val_loss:.8f}, "
|
||||||
|
f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
|
||||||
|
f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
|
||||||
|
f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
|
||||||
|
f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
|
||||||
|
f"Val kl {val_metrics['kl']:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add reward function specific metrics
|
||||||
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
|
val_metrics_str += (
|
||||||
|
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
|
||||||
|
f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Iter {it}: {val_metrics_str}, "
|
||||||
|
f"Val took {val_time:.3f}s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_callback is not None:
|
||||||
|
training_callback.on_val_loss_report({
|
||||||
|
"iteration": it,
|
||||||
|
"val_loss": val_loss,
|
||||||
|
**{f"val_{k}": v for k, v in val_metrics.items()},
|
||||||
|
"val_time": val_time,
|
||||||
|
})
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
loss, toks, metrics = step(batch)
|
||||||
|
losses += loss
|
||||||
|
n_tokens += toks
|
||||||
|
steps += 1
|
||||||
|
|
||||||
|
for k, v in metrics.items():
|
||||||
|
accumulated_metrics[k] += v
|
||||||
|
|
||||||
|
mx.eval(state, losses, n_tokens)
|
||||||
|
|
||||||
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
|
stop = time.perf_counter()
|
||||||
|
|
||||||
|
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||||
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
|
avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
|
||||||
|
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||||
|
learning_rate = optimizer.learning_rate.item()
|
||||||
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
|
trained_tokens += n_tokens
|
||||||
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
train_metrics_str = (
|
||||||
|
f"Train loss {train_loss:.8f}, "
|
||||||
|
f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
|
||||||
|
f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
|
||||||
|
f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
|
||||||
|
f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
|
||||||
|
f"KL {avg_metrics['kl']:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add reward function specific metrics
|
||||||
|
for i, reward_func in enumerate(reward_funcs):
|
||||||
|
func_name = reward_func.__name__
|
||||||
|
train_metrics_str += (
|
||||||
|
f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
|
||||||
|
f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Iter {it}: {train_metrics_str}, "
|
||||||
|
f"Learning Rate {learning_rate:.3e}, "
|
||||||
|
f"It/sec {it_sec:.3f}, "
|
||||||
|
f"Tokens/sec {tokens_sec:.3f}, "
|
||||||
|
f"Peak mem {peak_mem:.3f} GB",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_callback is not None:
|
||||||
|
training_callback.on_train_loss_report({
|
||||||
|
"iteration": it,
|
||||||
|
"train_loss": train_loss,
|
||||||
|
**{f"train_{k}": v for k, v in avg_metrics.items()},
|
||||||
|
"learning_rate": learning_rate,
|
||||||
|
"iterations_per_second": it_sec,
|
||||||
|
"tokens_per_second": tokens_sec,
|
||||||
|
"trained_tokens": trained_tokens,
|
||||||
|
"peak_memory": peak_mem,
|
||||||
|
})
|
||||||
|
|
||||||
|
losses = 0
|
||||||
|
n_tokens = 0
|
||||||
|
steps = 0
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
# Save adapter weights
|
||||||
|
if it % args.steps_per_save == 0:
|
||||||
|
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||||
|
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||||
|
checkpoint = (
|
||||||
|
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
|
||||||
|
)
|
||||||
|
mx.save_safetensors(str(checkpoint), adapter_weights)
|
||||||
|
print(
|
||||||
|
f"Iter {it}: Saved adapter weights to "
|
||||||
|
f"{args.adapter_file} and {checkpoint}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save final weights
|
||||||
|
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||||
|
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||||
|
print(f"Saved final weights to {args.adapter_file}.")
|
Loading…
Reference in New Issue
Block a user