mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
first successful training run
This commit is contained in:
@@ -63,11 +63,16 @@ CONFIG_DEFAULTS = {
|
||||
"config": None,
|
||||
"grad_checkpoint": False,
|
||||
"lr_schedule": None,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
|
||||
# GRPO args
|
||||
"reference_model_path": None,
|
||||
"group_size": 4,
|
||||
"beta": 0.1,
|
||||
"epsilon": 1e-4,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
"max_completion_length": 512,
|
||||
"use_chat_template": False,
|
||||
"use_prompt": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -178,9 +183,15 @@ def build_parser():
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
help="Number of responses per prompt.",
|
||||
help="Number of generations.",
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-completion-length",
|
||||
type=int,
|
||||
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
|
||||
default=512,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta",
|
||||
type=float,
|
||||
@@ -193,6 +204,18 @@ def build_parser():
|
||||
help="The Epsilon for numerical stability.",
|
||||
default=1e-4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-chat-template",
|
||||
type=bool,
|
||||
help="If the model is a Chat model, use the Chat template.",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-prompt",
|
||||
type=bool,
|
||||
help="Rather to use the prompt from teh R1 paper.",
|
||||
default=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -262,6 +285,7 @@ def train_model(
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
max_completion_length=args.max_completion_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
@@ -273,7 +297,7 @@ def train_model(
|
||||
reference_model, _ = load(args.reference_model_path)
|
||||
reference_model = reference_model.freeze()
|
||||
else:
|
||||
reference_model, _ = None, None
|
||||
reference_model, _ = load(args.model)
|
||||
|
||||
train_grpo(
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user