Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
LoRA: some minor optimizations (#573)
* init training_args in training scope
* Add trainable parameters percentage
This commit is contained in:
parent d4e1de1d5b
commit 485180ae91
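The second change adds a helper that reports what share of the model's weights LoRA leaves trainable. For reference, a self-contained sketch of the same computation: the helper body mirrors the one introduced in the first hunk below, while ToyModel and its layer sizes are made up for illustration and are not part of the commit.

import mlx.nn as nn
from mlx.utils import tree_flatten


def print_trainable_parameters(model):
    # Sizes are reported in millions of parameters (10**6).
    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    print(
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )


class ToyModel(nn.Module):
    # Illustrative stand-in for the LoRA-converted model: freeze everything,
    # then unfreeze a small "adapter", roughly what linear_to_lora_layers does.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(512, 512)
        self.adapter = nn.Linear(512, 16)

    def __call__(self, x):
        return self.adapter(self.backbone(x))


model = ToyModel()
model.freeze()
model.adapter.unfreeze()
print_trainable_parameters(model)  # roughly 3% trainable in this toy setup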
@@ -194,6 +194,17 @@ def load_dataset(args):
     return train, valid, test
 
 
+def print_trainable_parameters(model):
+    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
+
+
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
@@ -205,10 +216,7 @@ def run(args, training_callback: TrainingCallback = None):
     # Convert linear layers to lora layers and unfreeze in the process
     linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
 
-    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
-    print(f"Total parameters {p:.3f}M")
-    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
-    print(f"Trainable parameters {p:.3f}M")
+    print_trainable_parameters(model)
 
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args)
@@ -217,8 +225,11 @@ def run(args, training_callback: TrainingCallback = None):
     if args.resume_adapter_file is not None:
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
+
+    if args.train:
+        print("Training")
         # init training args
-        trainingArgs = TrainingArgs(
+        training_args = TrainingArgs(
             batch_size=args.batch_size,
             iters=args.iters,
             val_batches=args.val_batches,
@@ -229,15 +240,14 @@ def run(args, training_callback: TrainingCallback = None):
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
         )
-    if args.train:
-        print("Training")
+
         model.train()
         opt = optim.Adam(learning_rate=args.learning_rate)
         # Train model
         train(
             model=model,
             tokenizer=tokenizer,
-            args=trainingArgs,
+            args=training_args,
             optimizer=opt,
             train_dataset=train_set,
             val_dataset=valid_set,
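Taken together, the last two hunks construct the training configuration only when training is actually requested, under the renamed training_args. The scoping idea in isolation, using a hypothetical stand-in dataclass rather than mlx_lm's real TrainingArgs and only illustrative argument values:

from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class TrainingArgs:
    # Hypothetical stand-in for the real TrainingArgs; only a few fields shown.
    batch_size: int
    iters: int
    val_batches: int


def run(args):
    # After this commit the config object is created inside the training branch,
    # so evaluation-only runs never construct it.
    if args.train:
        print("Training")
        training_args = TrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
        )
        print(training_args)


run(SimpleNamespace(train=True, batch_size=4, iters=1000, val_batches=25))
run(SimpleNamespace(train=False, batch_size=4, iters=1000, val_batches=25))  # nothing built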