Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
LoRA: some minor optimizations (#573)
* init training_args in training scope
* Add trainable parameters percentage
parent d4e1de1d5b
commit 485180ae91
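As context for the "Add trainable parameters percentage" change, here is a minimal standalone sketch (not part of this commit) of how the percentage is computed with mlx's tree_flatten. The toy model, its layer names, and sizes are made up for illustration; only freeze/unfreeze, parameters()/trainable_parameters(), and tree_flatten mirror what the patched script relies on.

# Standalone sketch: compute the trainable-parameter percentage for a toy mlx
# model. ToyModel and its layer sizes are hypothetical.
import mlx.nn as nn
from mlx.utils import tree_flatten


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(512, 512)   # stands in for frozen base weights
        self.adapter = nn.Linear(512, 8)  # stands in for a small trainable adapter


model = ToyModel()
model.freeze()            # freeze everything ...
model.adapter.unfreeze()  # ... then unfreeze only the adapter

total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
trainable_p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(
    f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
    f"({trainable_p:.3f}M/{total_p:.3f}M)"
)

With these made-up sizes the adapter accounts for roughly 1.5% of the weights, which is the kind of figure the new helper reports.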
@@ -194,6 +194,17 @@ def load_dataset(args):
     return train, valid, test
 
 
+def print_trainable_parameters(model):
+    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
+
+
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
@@ -205,10 +216,7 @@ def run(args, training_callback: TrainingCallback = None):
     # Convert linear layers to lora layers and unfreeze in the process
     linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
 
-    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
-    print(f"Total parameters {p:.3f}M")
-    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
-    print(f"Trainable parameters {p:.3f}M")
+    print_trainable_parameters(model)
 
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args)
@@ -217,27 +225,29 @@ def run(args, training_callback: TrainingCallback = None):
     if args.resume_adapter_file is not None:
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
-    # init training args
-    trainingArgs = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=args.adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
+
     if args.train:
         print("Training")
+        # init training args
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=args.adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+        )
+
         model.train()
         opt = optim.Adam(learning_rate=args.learning_rate)
         # Train model
         train(
             model=model,
             tokenizer=tokenizer,
-            args=trainingArgs,
+            args=training_args,
             optimizer=opt,
             train_dataset=train_set,
             val_dataset=valid_set,
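For the "init training_args in training scope" part, a rough standalone sketch of the resulting control flow; StubArgs and run_sketch are hypothetical stand-ins for the real TrainingArgs and run(), used only to show that the config object is now built exclusively inside the training branch.

# Hypothetical stand-ins (StubArgs, run_sketch) illustrating the scoping change:
# the training config is constructed only when training actually runs.
from dataclasses import dataclass


@dataclass
class StubArgs:
    batch_size: int
    iters: int


def run_sketch(train: bool, test: bool):
    if train:
        # Built only inside the training branch, rather than unconditionally
        # before it as in the previous revision.
        training_args = StubArgs(batch_size=4, iters=1000)
        print(f"training with {training_args}")
    if test:
        print("evaluating")


run_sketch(train=False, test=True)   # a test-only run never builds StubArgs
run_sketch(train=True, test=False)   # training with StubArgs(batch_size=4, iters=1000)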