adding wandb reporting to lora.py

Goekdeniz-Guelmez 2025-03-12 14:31:23 +01:00
parent 57175b7b95
commit 04537fa346


@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"mask_prompt": False,
"report_to_wandb": False
}
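The default lives here rather than in argparse because lora.py only falls back to CONFIG_DEFAULTS for options the user left unset, which is why the new flag below parses with default=None. Roughly, the merge works like the sketch below; merge_config_defaults is a hypothetical stand-in for the script's own merging logic, not the exact upstream code:

import types

def merge_config_defaults(parsed: dict, defaults: dict) -> types.SimpleNamespace:
    # Unset CLI options parse to None, so anything the user did not pass
    # falls back to the value in CONFIG_DEFAULTS (report_to_wandb -> False).
    for key, value in defaults.items():
        if parsed.get(key) is None:
            parsed[key] = value
    return types.SimpleNamespace(**parsed)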
@@ -169,6 +170,12 @@ def build_parser():
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument(
"--report-to-wandb",
action="store_true",
help="Report the training args to WandB.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser
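With the parser change in place, reporting would be enabled from the command line along these lines (the model and data paths are placeholders):

python lora.py --model <path_to_model> --data <path_to_data> --train --report-to-wandb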
@@ -267,6 +274,26 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
def run(args, training_callback: TrainingCallback = None):
    np.random.seed(args.seed)

    # Initialize WandB if requested
    if args.report_to_wandb:
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "wandb is not installed. Run `pip install wandb` or omit --report-to-wandb."
            )
        wandb.init(project="mlx-finetuning", config=vars(args))

        # Wrap the existing callback (if any) so both it and WandB see each report
        original_callback = training_callback

        class WandBCallback(TrainingCallback):
            def on_train_loss_report(self, train_info: dict):
                wandb.log(train_info)
                if original_callback:
                    original_callback.on_train_loss_report(train_info)

            def on_val_loss_report(self, val_info: dict):
                wandb.log(val_info)
                if original_callback:
                    original_callback.on_val_loss_report(val_info)

        training_callback = WandBCallback()

    print("Loading pretrained model")
    model, tokenizer = load(args.model)
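The wrapper relies on the two loss-report hooks of the tuner's TrainingCallback; as a sketch of the assumed interface, the base class is essentially:

class TrainingCallback:
    def on_train_loss_report(self, train_info: dict):
        # Called periodically with training metrics (loss, learning rate, ...)
        pass

    def on_val_loss_report(self, val_info: dict):
        # Called after each validation pass with validation metrics
        pass

Since each *_info payload is a plain dict, passing it straight to wandb.log works without translation. The run is never closed explicitly; wandb finishes it automatically at process exit, though an explicit wandb.finish() after training would make that deterministic.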