LoRA: some minor optimizations (#573)

* Init training_args in the training scope

* Add trainable parameters percentage printout
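
The second change replaces the two separate parameter printouts with a single helper that also reports what percentage of the model is trainable after the LoRA conversion; the first simply moves the TrainingArgs construction inside the if args.train: block so it is only built when training actually runs. A minimal standalone sketch of the new helper, lifted from the diff below and assuming tree_flatten is imported from mlx.utils (the sample output in the comment is illustrative only):

from mlx.utils import tree_flatten  # flattens a nested parameter dict into (name, array) pairs


def print_trainable_parameters(model):
    # Count all parameters vs. only the unfrozen (LoRA) ones, in millions of elements.
    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    # e.g. "Trainable parameters: 0.117% (0.852M/728.531M)" -- illustrative numbers only
    print(
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )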
madroid 2024-03-14 11:26:30 +08:00 committed by GitHub
parent d4e1de1d5b
commit 485180ae91

@@ -194,6 +194,17 @@ def load_dataset(args):
     return train, valid, test


+def print_trainable_parameters(model):
+    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
+
+
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
@@ -205,10 +216,7 @@ def run(args, training_callback: TrainingCallback = None):
     # Convert linear layers to lora layers and unfreeze in the process
     linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)

-    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
-    print(f"Total parameters {p:.3f}M")
-    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
-    print(f"Trainable parameters {p:.3f}M")
+    print_trainable_parameters(model)

     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args)
@@ -217,27 +225,29 @@ def run(args, training_callback: TrainingCallback = None):
     if args.resume_adapter_file is not None:
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)

-    # init training args
-    trainingArgs = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=args.adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
-
     if args.train:
         print("Training")
+        # init training args
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=args.adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+        )
         model.train()
         opt = optim.Adam(learning_rate=args.learning_rate)

         # Train model
         train(
             model=model,
             tokenizer=tokenizer,
-            args=trainingArgs,
+            args=training_args,
             optimizer=opt,
             train_dataset=train_set,
             val_dataset=valid_set,