mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-27 03:05:20 +08:00)

update lora.py

This commit is contained in:
  parent a57d553fc1
  commit 243c9621d9
@@ -43,6 +43,7 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "training_mode": "normal",
     "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
@@ -62,6 +63,10 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
+    "reference_model_path": None,
+    "group_size": 4,
+    "beta": 0.1,
+    "epsilon": 1e-4,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }

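The four new defaults above are the GRPO knobs that the rest of this commit wires through the script. As a rough sketch only (the actual YAML-merge logic in lora.py sits outside this hunk), a config enabling GRPO could override them like so:

import yaml

# Hypothetical config; keys mirror the CONFIG_DEFAULTS entries added above.
user_config = yaml.safe_load("""
training_mode: grpo
group_size: 8
beta: 0.05
epsilon: 1.0e-4
reference_model_path: null
""")

# Assumed merge for illustration, not the script's real code path.
merged = {**CONFIG_DEFAULTS, **user_config}
assert merged["training_mode"] == "grpo" and merged["group_size"] == 8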
@@ -95,6 +100,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--training-mode",
+        type=str,
+        choices=["normal", "grpo"],
+        help="Training mode: normal or GRPO",
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -162,6 +173,25 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
+
+    parser.add_argument(
+        "--group-size",
+        type=int,
+        help="Number of responses per prompt.",
+        default=4,
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="KL penalty coefficient.",
+        default=0.1,
+    )
+    parser.add_argument(
+        "--epsilon",
+        type=float,
+        help="The Epsilon for numerical stability.",
+        default=1e-4,
+    )
     return parser


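A quick way to sanity-check the new flags is to run them through the parser built above. This is a usage sketch under assumptions: that the patched lora.py is importable as `lora` and that parse_args needs no other arguments here.

from lora import build_parser  # assumed import path

parser = build_parser()
args = parser.parse_args([
    "--training-mode", "grpo",
    "--group-size", "4",
    "--beta", "0.1",
    "--epsilon", "1e-4",
])
# The GRPO options land on the namespace next to the existing ones.
print(args.training_mode, args.group_size, args.beta, args.epsilon)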
@@ -221,32 +251,98 @@ def train_model(
         )
     )
     # Train model
-    train(
-        model=model,
-        tokenizer=tokenizer,
-        args=training_args,
-        optimizer=opt,
-        train_dataset=train_set,
-        val_dataset=valid_set,
-        training_callback=training_callback,
-    )
+    if args.training_mode == "grpo":
+        training_args = GRPOTrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            reference_model_path=args.reference_model_path
+        )
+
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model, _ = load(args.model)
+
+        train_grpo(
+            model=model,
+            reference_model=reference_model.freeze(),
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
+    else:
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint
+        )
+
+        train(
+            model=model,
+            tokenizer=tokenizer,
+            args=training_args,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            training_callback=training_callback,
+        )


 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()

-    test_loss = evaluate(
-        model=model,
-        dataset=test_set,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        num_batches=args.test_batches,
-        max_seq_length=args.max_seq_length,
-    )
-
-    test_ppl = math.exp(test_loss)
-
-    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+    if args.training_mode == "grpo":
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
+
+        test_loss, test_rewards = evaluate_grpo(
+            model=model,
+            reference_model=reference_model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            reference_model_path=args.reference_model_path
+        )
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+    else:
+        test_loss = evaluate(
+            model=model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+        )
+
+        test_ppl = math.exp(test_loss)
+
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


 def run(args, training_callback: TrainingCallback = None):
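In the GRPO branch above, group_size controls how many completions are sampled per prompt and epsilon is a numerical-stability term (per the --epsilon help text). Below is a minimal sketch of the standard group-relative advantage computation those two values usually feed, written against mlx.core as an assumption rather than the actual body of train_grpo or evaluate_grpo:

import mlx.core as mx

def group_relative_advantages(rewards: mx.array, group_size: int, epsilon: float = 1e-4) -> mx.array:
    # rewards: one scalar per sampled completion, length = num_prompts * group_size.
    grouped = rewards.reshape(-1, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.sqrt(mx.mean((grouped - mean) ** 2, axis=1, keepdims=True))
    # epsilon keeps the division finite when every completion in a group scores the same.
    return ((grouped - mean) / (std + epsilon)).reshape(-1)

# Example: two prompts, four completions each.
adv = group_relative_advantages(mx.array([1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5]), group_size=4)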
@@ -22,13 +22,7 @@ generate()
 class GRPOTrainingArgs(TrainingArgs):
     group_size: int = field(
         default=4,
-        metadata={"help": "Number of response sper prompt."},
+        metadata={"help": "Number of responses per prompt."},
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     beta: float = field(
         default=0.1, metadata={"help": "KL penalty coefficient."}
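beta is documented here as the KL penalty coefficient. A common way such a penalty is formed in GRPO-style trainers is the non-negative per-token estimator exp(ref - pol) - (ref - pol) - 1 scaled by beta; the sketch below shows that form as an assumption, not the verified loss code of this trainer:

import mlx.core as mx

def kl_penalty(policy_logprobs: mx.array, ref_logprobs: mx.array, beta: float) -> mx.array:
    # Per-token estimate of KL(policy || reference); always >= 0, and zero when the
    # policy matches the frozen reference model loaded in train_model above.
    diff = ref_logprobs - policy_logprobs
    return beta * (mx.exp(diff) - diff - 1.0)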