removing comments + adding temperature + reward weighting

Goekdeniz-Guelmez
2025-02-15 15:29:22 +01:00
parent baeb9f117f
commit 5ec4790656
2 changed files with 64 additions and 52 deletions


@@ -1,21 +1,21 @@
 # Copyright © 2024 Apple Inc.
 
-from pathlib import Path
 import argparse
-import types
 import math
 import os
 import re
+import types
+from pathlib import Path
 
+import mlx.nn as nn
 import mlx.optimizers as optim
-import mlx.nn as nn
 import numpy as np
 import yaml
 
-from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
-from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
     "max_completion_length": 512,
     "use_chat_template": False,
     "use_prompt": False,
+    "temperature": 1.0,
+    "reward_weights": None,
 }
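
Since CONFIG_DEFAULTS also backs the script's YAML config path (the file imports yaml above), the two new keys should be settable from a config file as well. A minimal sketch under that assumption; the key names mirror the defaults added in this hunk:

import yaml

# Hypothetical config snippet; keys mirror the CONFIG_DEFAULTS entries above.
snippet = """
temperature: 0.8      # below the 1.0 default, so sampling is less random
reward_weights: null  # null keeps every reward function at weight 1.0
"""
config = yaml.safe_load(snippet)
assert config["temperature"] == 0.8
assert config["reward_weights"] is None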
@@ -224,6 +226,18 @@ def build_parser():
         help="Whether to use the prompt from the R1 paper.",
         default=None,
     )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        help="Temperature for sampling. The higher the temperature, the more random the completions.",
+        default=1.0,
+    )
+    parser.add_argument(
+        "--reward-weights",
+        type=str,
+        help="Weights for each reward function. Must match the number of reward functions and be given in this format: [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
+        default=None,
+    )
     return parser
 
 def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
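
As a quick sanity check, the two new flags can be driven straight through the build_parser() defined above; a small sketch, assuming no other arguments are required (the values are made up):

# Exercise the new flags; --reward-weights stays a raw string at this point
# and is only converted to floats later, in train_model_grpo (next hunk).
args = build_parser().parse_args(
    ["--temperature", "0.7", "--reward-weights", "[1.0, 0.5, 0.5]"]
)
assert args.temperature == 0.7
assert args.reward_weights == "[1.0, 0.5, 0.5]"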
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
         beta=args.beta,
         group_size=args.group_size,
         epsilon=args.epsilon,
-        reference_model_path=args.reference_model_path
+        reference_model_path=args.reference_model_path,
+        temperature=args.temperature,
+        reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
     )
     if args.reference_model_path:
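
The inline list comprehension converts the raw flag by stripping the brackets and splitting on commas, so "[1.0, 0.5]" becomes [1.0, 0.5] (malformed input raises ValueError). How the trainer consumes the weights is not shown in this diff; a minimal sketch of the usual weighted-sum combination, with hypothetical names (reward_funcs, completions):

def combine_rewards(reward_funcs, completions, reward_weights=None):
    # One score per completion from each reward function.
    per_func = [func(completions) for func in reward_funcs]
    # Default mirrors the help text: every function weighted equally at 1.0.
    weights = reward_weights or [1.0] * len(reward_funcs)
    if len(weights) != len(reward_funcs):
        raise ValueError("number of reward weights must match number of reward functions")
    # Weighted sum across reward functions, per completion.
    return [sum(w * s for w, s in zip(weights, scores)) for scores in zip(*per_func)]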