Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00.
adding wandb reporting to lora.py

Commit 04537fa346 (parent 57175b7b95)

This commit adds an opt-in --report-to-wandb flag to lora.py: a new entry in CONFIG_DEFAULTS, a matching CLI argument in build_parser(), and a wrapper callback in run() that logs training and validation metrics to Weights & Biases.
@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "mask_prompt": False,
+    "report_to_wandb": False
 }
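The new key gets its False default here in CONFIG_DEFAULTS rather than in argparse: lora.py resolves any option still at None after parsing from the config file and then from these defaults, so None on a flag means "not set on the command line". A minimal sketch of that layering, under the assumption that resolution works roughly this way; merge_config is a hypothetical helper, not code from the script:

# Illustrative sketch of the defaults layering lora.py relies on;
# merge_config is a hypothetical helper, not code from the script.
def merge_config(cli_args: dict, file_config: dict, defaults: dict) -> dict:
    merged = dict(defaults)  # lowest priority: CONFIG_DEFAULTS
    merged.update({k: v for k, v in file_config.items() if v is not None})
    merged.update({k: v for k, v in cli_args.items() if v is not None})  # CLI wins
    return merged

defaults = {"report_to_wandb": False}
print(merge_config({"report_to_wandb": None}, {}, defaults))  # {'report_to_wandb': False}
print(merge_config({"report_to_wandb": True}, {}, defaults))  # {'report_to_wandb': True}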
@@ -169,6 +170,12 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
+    parser.add_argument(
+        "--report-to-wandb",
+        action="store_true",
+        help="Report the training args to WandB.",
+        default=None,
+    )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
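Because the argument combines action="store_true" with default=None, parsing yields True when the flag is passed and None when it is omitted, and the None is later resolved to False by CONFIG_DEFAULTS. A quick check, assuming lora.py sits on the import path:

# Quick check of the flag's behavior; assumes lora.py is importable.
from lora import build_parser

parser = build_parser()
print(parser.parse_args(["--report-to-wandb"]).report_to_wandb)  # True
print(parser.parse_args([]).report_to_wandb)  # None -> resolved to False later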
@@ -267,6 +274,26 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
+    # Initialize WandB if requested
+    if args.report_to_wandb:
+        import wandb
+        wandb.init(project="mlx-finetuning", config=vars(args))
+
+        # Create a simple wandb callback that wraps the existing one
+        original_callback = training_callback
+        class WandBCallback(TrainingCallback):
+            def on_train_loss_report(self, train_info: dict):
+                wandb.log(train_info)
+                if original_callback:
+                    original_callback.on_train_loss_report(train_info)
+
+            def on_val_loss_report(self, val_info: dict):
+                wandb.log(val_info)
+                if original_callback:
+                    original_callback.on_val_loss_report(val_info)
+
+        training_callback = WandBCallback()
+
     print("Loading pretrained model")
     model, tokenizer = load(args.model)