update lora.py

This commit is contained in:
Goekdeniz-Guelmez 2025-01-31 21:10:44 +01:00
parent a57d553fc1
commit 243c9621d9
2 changed files with 117 additions and 27 deletions

View File

@ -43,6 +43,7 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"training_mode": "normal",
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
@ -62,6 +63,10 @@ CONFIG_DEFAULTS = {
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
@ -95,6 +100,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "grpo"],
help="Training mode: normal or GRPO",
)
parser.add_argument(
"--num-layers",
type=int,
@ -162,6 +173,25 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument(
"--group-size",
type=int,
help="Number of responses per prompt.",
default=4,
)
parser.add_argument(
"--beta",
type=float,
help="KL penalty coefficient.",
default=0.1,
)
parser.add_argument(
"--epsilon",
type=float,
help="The Epsilon for numerical stability.",
default=1e-4,
)
return parser
@ -221,32 +251,98 @@ def train_model(
)
)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
if args.training_mode == "grpo":
training_args = GRPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
)
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)
train_grpo(
model=model,
reference_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
if args.training_mode == "grpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model = model
test_ppl = math.exp(test_loss)
test_loss, test_rewards = evaluate_grpo(
model=model,
reference_model=reference_model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
)
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None):

View File

@ -22,13 +22,7 @@ generate()
class GRPOTrainingArgs(TrainingArgs):
group_size: int = field(
default=4,
metadata={"help": "Number of response sper prompt."},
)
is_reference_free: bool = field(
default=False,
metadata={
"help": "Whether to use reference-free DPO training."
}
metadata={"help": "Number of responses per prompt."},
)
beta: float = field(
default=0.1, metadata={"help": "KL penalty coefficient."}