mirror of https://github.com/ml-explore/mlx-examples.git
removing comments + adding temperature + reward weighting
@@ -1,21 +1,21 @@
 # Copyright © 2024 Apple Inc.
 
-from pathlib import Path
 import argparse
-import types
 import math
 import os
 import re
+import types
+from pathlib import Path
 
+import mlx.nn as nn
 import mlx.optimizers as optim
-import mlx.nn as nn
 import numpy as np
 import yaml
 
-from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
-from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
     "max_completion_length": 512,
     "use_chat_template": False,
     "use_prompt": False,
+    "temperature": 1.0,
+    "reward_weights": None,
 }
 
 
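The `reward_weights` default of `None` is meant to signal equal weighting rather than no weighting. A minimal sketch of that convention, assuming the trainer combines per-function rewards as a weighted sum (the `combine_rewards` helper and the array shapes are illustrative, not the trainer's actual code):

import numpy as np

# Illustrative only: weighted sum of rewards from several reward functions,
# with reward_weights=None meaning every function gets weight 1.0.
def combine_rewards(rewards_per_func, reward_weights=None):
    rewards = np.asarray(rewards_per_func, dtype=np.float32)  # (n_funcs, n_samples)
    if reward_weights is None:
        reward_weights = [1.0] * rewards.shape[0]
    weights = np.asarray(reward_weights, dtype=np.float32)
    return weights @ rewards  # one combined reward per sampled completion

print(combine_rewards([[1.0, 0.0], [0.5, 0.5]], [0.2, 0.8]))  # [0.6 0.4]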
@@ -224,6 +226,18 @@ def build_parser():
         help="Whether to use the prompt from the R1 paper.",
         default=None,
     )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        help="Temperature for sampling. The higher the temperature, the more random the completions.",
+        default=1.0,
+    )
+    parser.add_argument(
+        "--reward-weights",
+        type=str,
+        help="Weights for each reward function. Must match the number of reward functions and be in this format: [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
+        default=None,
+    )
     return parser
 
 def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
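The help text summarizes temperature's effect; as an illustration (not the trainer's sampling code, which operates on MLX arrays), dividing logits by the temperature before the softmax flattens the distribution as the temperature grows:

import numpy as np

# Illustrative sampling step: logits / temperature, then softmax, then draw.
def sample_token(logits, temperature=1.0, seed=0):
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())  # stabilized softmax
    probs /= probs.sum()
    return np.random.default_rng(seed).choice(len(probs), p=probs)

print(sample_token([2.0, 1.0, 0.1], temperature=0.1))  # near-greedy
print(sample_token([2.0, 1.0, 0.1], temperature=5.0))  # close to uniform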
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
         beta=args.beta,
         group_size=args.group_size,
         epsilon=args.epsilon,
-        reference_model_path=args.reference_model_path
+        reference_model_path=args.reference_model_path,
+        temperature=args.temperature,
+        reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
     )
 
     if args.reference_model_path:
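The `--reward-weights` string is parsed inline above; pulled out as a standalone snippet for clarity (the `parse_reward_weights` name is hypothetical, the diff uses the list comprehension directly):

# Hypothetical helper mirroring the inline parsing in the diff above.
def parse_reward_weights(raw):
    if not raw:
        return None  # trainer then weights every reward function equally
    # "[0.1, 0.2]" -> ["0.1", " 0.2"] -> [0.1, 0.2]; float() tolerates spaces
    return [float(x) for x in raw.strip("[]").split(",")]

print(parse_reward_weights("[0.5, 1.0]"))  # [0.5, 1.0]
print(parse_reward_weights(None))          # None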