mirror of https://github.com/ml-explore/mlx-examples.git
update lora.py

This commit is contained in:
parent a57d553fc1
commit 243c9621d9
@@ -43,6 +43,7 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "training_mode": "normal",
     "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
@@ -62,6 +63,10 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
+    "reference_model_path": None,
+    "group_size": 4,
+    "beta": 0.1,
+    "epsilon": 1e-4,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
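
The new defaults mirror the GRPO-related CLI flags added further down. As a minimal sketch (not part of this commit, and assuming a plain dict merge rather than lora.py's actual YAML/argparse config handling), per-run settings would override them like this:

    # Illustration only: overriding the new GRPO defaults with per-run settings.
    # The dict merge is hypothetical; it is not the loader used in this diff.
    user_config = {"training_mode": "grpo", "group_size": 8, "beta": 0.05}
    run_config = {**CONFIG_DEFAULTS, **user_config}
    assert run_config["epsilon"] == 1e-4  # untouched keys keep their defaults
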
@@ -95,6 +100,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--training-mode",
+        type=str,
+        choices=["normal", "grpo"],
+        help="Training mode: normal or GRPO",
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -162,6 +173,25 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
+
+    parser.add_argument(
+        "--group-size",
+        type=int,
+        help="Number of responses per prompt.",
+        default=4,
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="KL penalty coefficient.",
+        default=0.1,
+    )
+    parser.add_argument(
+        "--epsilon",
+        type=float,
+        help="The Epsilon for numerical stability.",
+        default=1e-4,
+    )
     return parser
 
 
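
For orientation (not part of the diff), the new flags can be exercised directly through build_parser(); attribute names below follow argparse's usual dash-to-underscore conversion:

    # Hypothetical usage sketch of the new GRPO flags; assumes build_parser()
    # is imported from the module this commit modifies.
    parser = build_parser()
    args = parser.parse_args(
        ["--training-mode", "grpo", "--group-size", "8", "--beta", "0.05", "--epsilon", "1e-4"]
    )
    print(args.training_mode, args.group_size, args.beta, args.epsilon)
    # -> grpo 8 0.05 0.0001
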
@@ -221,32 +251,98 @@ def train_model(
         )
     )
-    # Train model
-    train(
-        model=model,
-        tokenizer=tokenizer,
-        args=training_args,
-        optimizer=opt,
-        train_dataset=train_set,
-        val_dataset=valid_set,
-        training_callback=training_callback,
-    )
+    if args.training_mode == "grpo":
+        training_args = GRPOTrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            reference_model_path=args.reference_model_path
+        )
+
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model, _ = load(args.model)
+
+        train_grpo(
+            model=model,
+            reference_model=reference_model.freeze(),
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
+    else:
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint
+        )
+
+        train(
+            model=model,
+            tokenizer=tokenizer,
+            args=training_args,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            training_callback=training_callback,
+        )
 
 
 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()
 
-    test_loss = evaluate(
-        model=model,
-        dataset=test_set,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        num_batches=args.test_batches,
-        max_seq_length=args.max_seq_length,
-    )
+    if args.training_mode == "grpo":
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
 
-    test_ppl = math.exp(test_loss)
+        test_loss, test_rewards = evaluate_grpo(
+            model=model,
+            reference_model=reference_model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            reference_model_path=args.reference_model_path
+        )
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+    else:
+        test_loss = evaluate(
+            model=model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+        )
 
-    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+        test_ppl = math.exp(test_loss)
+
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
 
 def run(args, training_callback: TrainingCallback = None):
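
The group_size, beta, and epsilon values threaded through both functions above correspond to the usual GRPO recipe: group_size responses are sampled per prompt, beta weights the KL penalty against the frozen reference model, and epsilon stabilizes the per-group advantage normalization. A sketch of that normalization, as an illustration of the standard formulation rather than code from this commit:

    # Hypothetical sketch: group-normalized advantages as in the usual GRPO setup.
    import mlx.core as mx

    def group_advantages(rewards: mx.array, epsilon: float = 1e-4) -> mx.array:
        # rewards has shape (num_prompts, group_size): one reward per sampled response
        mean = mx.mean(rewards, axis=-1, keepdims=True)
        std = mx.std(rewards, axis=-1, keepdims=True)
        return (rewards - mean) / (std + epsilon)
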
@@ -297,4 +393,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
@@ -22,13 +22,7 @@ generate()
 class GRPOTrainingArgs(TrainingArgs):
     group_size: int = field(
         default=4,
-        metadata={"help": "Number of response sper prompt."},
-    )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
+        metadata={"help": "Number of responses per prompt."},
     )
     beta: float = field(
         default=0.1, metadata={"help": "KL penalty coefficient."}
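
Taken together with the lora.py changes, the dataclass can be populated with the same values train_model passes in. A hypothetical construction, with field names taken from the hunks above and the concrete values assumed for illustration:

    # Hypothetical: building the GRPO config with fields shown in this commit.
    training_args = GRPOTrainingArgs(
        batch_size=1,
        iters=200,
        group_size=4,                # responses sampled per prompt
        beta=0.1,                    # KL penalty coefficient
        epsilon=1e-4,                # numerical-stability term
        reference_model_path=None,   # fall back to the base model as reference
    )
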