update lora.py

This commit is contained in:
Goekdeniz-Guelmez 2025-01-31 21:10:44 +01:00
parent a57d553fc1
commit 243c9621d9
2 changed files with 117 additions and 27 deletions

View File

@ -43,6 +43,7 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = { CONFIG_DEFAULTS = {
"model": "mlx_model", "model": "mlx_model",
"train": False, "train": False,
"training_mode": "normal",
"fine_tune_type": "lora", "fine_tune_type": "lora",
"data": "data/", "data": "data/",
"seed": 0, "seed": 0,
@ -62,6 +63,10 @@ CONFIG_DEFAULTS = {
"config": None, "config": None,
"grad_checkpoint": False, "grad_checkpoint": False,
"lr_schedule": None, "lr_schedule": None,
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
} }
@ -95,6 +100,12 @@ def build_parser():
choices=["lora", "dora", "full"], choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.", help="Type of fine-tuning to perform: lora, dora, or full.",
) )
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "grpo"],
help="Training mode: normal or GRPO",
)
parser.add_argument( parser.add_argument(
"--num-layers", "--num-layers",
type=int, type=int,
@ -162,6 +173,25 @@ def build_parser():
default=None, default=None,
) )
parser.add_argument("--seed", type=int, help="The PRNG seed") parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument(
"--group-size",
type=int,
help="Number of responses per prompt.",
default=4,
)
parser.add_argument(
"--beta",
type=float,
help="KL penalty coefficient.",
default=0.1,
)
parser.add_argument(
"--epsilon",
type=float,
help="The Epsilon for numerical stability.",
default=1e-4,
)
return parser return parser
@ -221,32 +251,98 @@ def train_model(
) )
) )
# Train model # Train model
train( if args.training_mode == "grpo":
model=model, training_args = GRPOTrainingArgs(
tokenizer=tokenizer, batch_size=args.batch_size,
args=training_args, iters=args.iters,
optimizer=opt, val_batches=args.val_batches,
train_dataset=train_set, steps_per_report=args.steps_per_report,
val_dataset=valid_set, steps_per_eval=args.steps_per_eval,
training_callback=training_callback, steps_per_save=args.save_every,
) adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
)
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)
train_grpo(
model=model,
reference_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set): def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval() model.eval()
test_loss = evaluate( if args.training_mode == "grpo":
model=model, if args.reference_model_path:
dataset=test_set, reference_model, _ = load(args.reference_model_path)
tokenizer=tokenizer, else:
batch_size=args.batch_size, reference_model = model
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
test_ppl = math.exp(test_loss) test_loss, test_rewards = evaluate_grpo(
model=model,
reference_model=reference_model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
)
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None): def run(args, training_callback: TrainingCallback = None):

View File

@ -22,13 +22,7 @@ generate()
class GRPOTrainingArgs(TrainingArgs): class GRPOTrainingArgs(TrainingArgs):
group_size: int = field( group_size: int = field(
default=4, default=4,
metadata={"help": "Number of response sper prompt."}, metadata={"help": "Number of responses per prompt."},
)
is_reference_free: bool = field(
default=False,
metadata={
"help": "Whether to use reference-free DPO training."
}
) )
beta: float = field( beta: float = field(
default=0.1, metadata={"help": "KL penalty coefficient."} default=0.1, metadata={"help": "KL penalty coefficient."}