mirror of https://github.com/ml-explore/mlx-examples.git
initial commit
@@ -15,6 +15,7 @@ import yaml
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
+from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
@@ -43,6 +44,7 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "training_mode": "normal",
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -62,6 +64,12 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "beta": 0.1,
+    "dpo_loss_type": "sigmoid",
+    "is_reference_free": False,
+    "delta": 50.0,
+    "reference_model_path": None,
+    "train_bias_only": False,
 }
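The new keys extend CONFIG_DEFAULTS with the DPO hyperparameters (beta for the preference-margin scale, the loss variant, delta, and an optional explicit reference model). As a minimal sketch of how a user-supplied YAML file could override these defaults — assuming lora.py keeps merging CONFIG_DEFAULTS with a YAML config as the existing yaml import suggests; the file name and merge order below are illustrative, not part of this commit:

import yaml

CONFIG_DEFAULTS = {
    "training_mode": "normal",
    "beta": 0.1,
    "dpo_loss_type": "sigmoid",
    "is_reference_free": False,
    "delta": 50.0,
    "reference_model_path": None,
    "train_bias_only": False,
}

# e.g. the file contains: training_mode: dpo  /  beta: 0.05
with open("dpo_config.yaml") as f:
    user_config = yaml.safe_load(f) or {}

config = {**CONFIG_DEFAULTS, **user_config}  # user values override the defaults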
@@ -94,6 +102,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--training-mode",
+        type=str,
+        choices=["normal", "dpo"],
+        help="Training mode: normal or DPO",
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -160,6 +174,12 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
+    parser.add_argument("--beta", type=float)
+    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
+    parser.add_argument("--is-reference-free", action="store_true")
+    parser.add_argument("--delta", type=float)
+    parser.add_argument("--reference-model-path", type=str)
+    parser.add_argument("--train-bias-only", action="store_true")
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
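As a quick usage sketch, the new flags can be exercised directly through build_parser(). The values below are illustrative, and the options not shown in this diff (--model, --train) are assumed to keep their existing behavior in lora.py:

# Illustrative only: exercises the DPO flags added in this diff.
parser = build_parser()
args = parser.parse_args([
    "--model", "mlx_model",
    "--train",
    "--training-mode", "dpo",
    "--beta", "0.1",
    "--dpo-loss-type", "sigmoid",
    "--reference-model-path", "mlx_model",
])
print(args.training_mode, args.beta, args.dpo_loss_type)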
@@ -200,19 +220,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
-    # init training args
-    training_args = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
-
     model.train()
     opt = optim.Adam(
         learning_rate=(
@@ -220,32 +227,99 @@ def train_model(
         )
     )
-    # Train model
-    train(
-        model=model,
-        tokenizer=tokenizer,
-        args=training_args,
-        optimizer=opt,
-        train_dataset=train_set,
-        val_dataset=valid_set,
-        training_callback=training_callback,
-    )
+    if args.training_mode == "dpo":
+        training_args = DPOTrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+            beta=args.beta,
+            loss_type=args.dpo_loss_type,
+            is_reference_free=args.is_reference_free,
+            delta=args.delta,
+            reference_model_path=args.reference_model_path,
+            train_bias_only=args.train_bias_only,
+        )
+
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
+
+        train_dpo(
+            model=model,
+            reference_model=reference_model,
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
+    else:
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+        )
+
+        train(
+            model=model,
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
 
 
 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()
 
-    test_loss = evaluate(
-        model=model,
-        dataset=test_set,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        num_batches=args.test_batches,
-        max_seq_length=args.max_seq_length,
-    )
-
-    test_ppl = math.exp(test_loss)
-
-    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+    if args.training_mode == "dpo":
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
+
+        test_loss, test_rewards = evaluate_dpo(
+            model=model,
+            reference_model=reference_model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+            beta=args.beta,
+            delta=args.delta,
+            loss_type=args.dpo_loss_type,
+        )
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+    else:
+        test_loss = evaluate(
+            model=model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+        )
+
+        test_ppl = math.exp(test_loss)
+
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
 
 def run(args, training_callback: TrainingCallback = None):
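For context on what beta, the loss type, and the reference model feed into: below is a minimal hand-written sketch of the standard sigmoid DPO objective (Rafailov et al., 2023). It is not the actual code in .tuner.dpo_trainer — the function name and the summed sequence log-probability inputs are assumptions — and the is_reference_free, delta, and hinge/ipo/dpop options presumably select variants of this objective inside the trainer.

import mlx.core as mx

def dpo_sigmoid_loss(policy_chosen_lp, policy_rejected_lp,
                     ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Sigmoid DPO loss over summed sequence log-probabilities (sketch)."""
    # How much more the policy prefers the chosen response than the reference does
    logits = beta * (
        (policy_chosen_lp - policy_rejected_lp)
        - (ref_chosen_lp - ref_rejected_lp)
    )
    # -log(sigmoid(logits)) == log(1 + exp(-logits)), via logaddexp for stability
    return mx.mean(mx.logaddexp(mx.zeros_like(logits), -logits))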
@@ -263,7 +337,7 @@ def run(args, training_callback: TrainingCallback = None):
         load_adapters(model, args.adapter_path)
 
     elif args.train:
-        print("Training")
+        print(f"Training in {args.training_mode} mode")
         train_model(args, model, tokenizer, train_set, valid_set, training_callback)
     else:
         raise ValueError("Must provide at least one of --train or --test")