initial commit

Goekdeniz-Guelmez
2025-01-19 00:19:36 +01:00
parent 07f88f8057
commit 1ff788821c
3 changed files with 705 additions and 34 deletions

@@ -15,6 +15,7 @@ import yaml
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
+from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
@@ -43,6 +44,7 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "training_mode": "normal",
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -62,6 +64,12 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "beta": 0.1,
+    "dpo_loss_type": "sigmoid",
+    "is_reference_free": False,
+    "delta": 50.0,
+    "reference_model_path": None,
+    "train_bias_only": False,
 }
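
For reference, the new keys are ordinary config values. A minimal YAML config exercising them might look like this (illustrative only; the model path is a placeholder, though the script does load YAML configs, as the `import yaml` context above shows):

    model: <path-or-hf-repo>
    train: true
    training_mode: dpo
    beta: 0.1
    dpo_loss_type: sigmoid
    delta: 50.0
    reference_model_path: null  # fall back to the policy model as its own reference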
@@ -94,6 +102,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--training-mode",
+        type=str,
+        choices=["normal", "dpo"],
+        help="Training mode: normal or DPO.",
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -160,6 +174,12 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
+    parser.add_argument("--beta", type=float, help="DPO beta (strength of the KL regularization toward the reference model).")
+    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"], help="DPO loss variant.")
+    parser.add_argument("--is-reference-free", action="store_true", help="Use reference-free DPO.")
+    parser.add_argument("--delta", type=float, help="Delta parameter used by the dpop loss.")
+    parser.add_argument("--reference-model-path", type=str, help="Path to the reference model for DPO.")
+    parser.add_argument("--train-bias-only", action="store_true", help="Train only bias parameters.")
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
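
Equivalently, the new flags can be passed on the command line. A hypothetical invocation (the `mlx_lm.lora` entry point and the paths are assumptions, not shown in this diff):

    python -m mlx_lm.lora --model <path-or-hf-repo> --train --data data/ \
        --training-mode dpo --beta 0.1 --dpo-loss-type sigmoid \
        --reference-model-path <path-to-reference-model>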
@@ -200,19 +220,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
-    # init training args
-    training_args = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
-
     model.train()
     opt = optim.Adam(
         learning_rate=(
@@ -220,32 +227,99 @@ def train_model(
         )
     )
 
     # Train model
-    train(
-        model=model,
-        tokenizer=tokenizer,
-        args=training_args,
-        optimizer=opt,
-        train_dataset=train_set,
-        val_dataset=valid_set,
-        training_callback=training_callback,
-    )
+    if args.training_mode == "dpo":
+        training_args = DPOTrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+            beta=args.beta,
+            loss_type=args.dpo_loss_type,
+            is_reference_free=args.is_reference_free,
+            delta=args.delta,
+            reference_model_path=args.reference_model_path,
+            train_bias_only=args.train_bias_only,
+        )
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
+        train_dpo(
+            model=model,
+            reference_model=reference_model,
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
+    else:
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+        )
+        train(
+            model=model,
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
 
 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()
 
-    test_loss = evaluate(
-        model=model,
-        dataset=test_set,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        num_batches=args.test_batches,
-        max_seq_length=args.max_seq_length,
-    )
-
-    test_ppl = math.exp(test_loss)
-
-    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+    if args.training_mode == "dpo":
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
+        test_loss, test_rewards = evaluate_dpo(
+            model=model,
+            reference_model=reference_model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+            beta=args.beta,
+            delta=args.delta,
+            loss_type=args.dpo_loss_type,
+        )
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+    else:
+        test_loss = evaluate(
+            model=model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+        )
+        test_ppl = math.exp(test_loss)
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
 def run(args, training_callback: TrainingCallback = None):
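
The DPO objective itself lives in tuner/dpo_trainer.py, which this view does not include. As a rough sketch of what `beta`, `loss_type="sigmoid"`, and the two reported rewards conventionally mean (an assumption about the unseen trainer, not code from this commit):

    import mlx.core as mx

    def dpo_sigmoid_loss(pi_chosen_logps, pi_rejected_logps,
                         ref_chosen_logps, ref_rejected_logps, beta=0.1):
        # Implicit rewards: beta-scaled log-prob ratios against the reference model.
        chosen_rewards = beta * (pi_chosen_logps - ref_chosen_logps)
        rejected_rewards = beta * (pi_rejected_logps - ref_rejected_logps)
        # Sigmoid DPO loss: -log sigmoid(chosen reward - rejected reward).
        loss = -mx.log(mx.sigmoid(chosen_rewards - rejected_rewards))
        # Mean loss, plus the (chosen, rejected) reward pair printed during evaluation.
        return mx.mean(loss), (mx.mean(chosen_rewards), mx.mean(rejected_rewards))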
@@ -263,7 +337,7 @@ def run(args, training_callback: TrainingCallback = None):
         load_adapters(model, args.adapter_path)
     elif args.train:
-        print("Training")
+        print(f"Training in {args.training_mode} mode")
         train_model(args, model, tokenizer, train_set, valid_set, training_callback)
     else:
         raise ValueError("Must provide at least one of --train or --test")
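
One thing this file leaves open is the preference-data format that `load_dataset` must produce for DPO. Preference-tuning datasets conventionally pair a prompt with a chosen and a rejected completion, so a hypothetical train.jsonl record (not taken from this commit) would be:

    {"prompt": "Explain beta in DPO.", "chosen": "<preferred answer>", "rejected": "<dispreferred answer>"}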