diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 36307802..5c2d1f00 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -194,6 +194,17 @@ def load_dataset(args):
     return train, valid, test
 
 
+def print_trainable_parameters(model):
+    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    trainable_p = (
+        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+    )
+    print(
+        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
+        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    )
+
+
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
@@ -205,10 +216,7 @@ def run(args, training_callback: TrainingCallback = None):
     # Convert linear layers to lora layers and unfreeze in the process
     linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
 
-    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
-    print(f"Total parameters {p:.3f}M")
-    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
-    print(f"Trainable parameters {p:.3f}M")
+    print_trainable_parameters(model)
 
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args)
@@ -217,27 +225,29 @@ def run(args, training_callback: TrainingCallback = None):
     if args.resume_adapter_file is not None:
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
-    # init training args
-    trainingArgs = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=args.adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
+
     if args.train:
         print("Training")
+        # init training args
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=args.adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+        )
+
         model.train()
         opt = optim.Adam(learning_rate=args.learning_rate)
         # Train model
         train(
             model=model,
             tokenizer=tokenizer,
-            args=trainingArgs,
+            args=training_args,
             optimizer=opt,
             train_dataset=train_set,
             val_dataset=valid_set,