Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
adding wandb reporting to lora.py

parent 57175b7b95
commit 04537fa346
@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "mask_prompt": False,
+    "report_to_wandb": False
 }
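
For context: the script appears to fall back to CONFIG_DEFAULTS for any option left unset on the command line, which is why the new key gets a plain False here rather than a parser default. A minimal sketch of that merge pattern, assuming the fallback works roughly like this (the helper name apply_config_defaults is ours, not the repo's):

import argparse

CONFIG_DEFAULTS = {"report_to_wandb": False}

def apply_config_defaults(args: argparse.Namespace) -> argparse.Namespace:
    # Any option the user never set arrives as None and falls back to
    # the CONFIG_DEFAULTS entry.
    for key, value in CONFIG_DEFAULTS.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args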
@@ -169,6 +170,12 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
+    parser.add_argument(
+        "--report-to-wandb",
+        action="store_true",
+        help="Report the training args to WandB.",
+        default=None,
+    )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
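
Pairing action="store_true" with default=None means an omitted flag stays None (so the CONFIG_DEFAULTS entry added above can take over), while passing the flag yields True. A standalone check of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--report-to-wandb", action="store_true", default=None)

print(parser.parse_args([]))                     # Namespace(report_to_wandb=None)
print(parser.parse_args(["--report-to-wandb"]))  # Namespace(report_to_wandb=True)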
@@ -267,6 +274,26 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
+    # Initialize WandB if requested
+    if args.report_to_wandb:
+        import wandb
+        wandb.init(project="mlx-finetuning", config=vars(args))
+
+        # Create a simple wandb callback that wraps the existing one
+        original_callback = training_callback
+        class WandBCallback(TrainingCallback):
+            def on_train_loss_report(self, train_info: dict):
+                wandb.log(train_info)
+                if original_callback:
+                    original_callback.on_train_loss_report(train_info)
+
+            def on_val_loss_report(self, val_info: dict):
+                wandb.log(val_info)
+                if original_callback:
+                    original_callback.on_val_loss_report(val_info)
+
+        training_callback = WandBCallback()
+
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
 
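
Two details of this hunk are worth noting: import wandb is deferred into the if args.report_to_wandb: branch, so wandb stays an optional dependency, and WandBCallback closes over original_callback, so a caller-supplied callback still receives every report. Below is a self-contained sketch of the same wrapper pattern; the stub, the wrap_with_wandb helper, and the logged keys are hypothetical stand-ins, not part of the commit:

class TrainingCallback:
    def on_train_loss_report(self, train_info: dict):
        pass

    def on_val_loss_report(self, val_info: dict):
        pass


class StubWandB:
    # Stands in for the real wandb module so the snippet runs without it.
    logged = []

    @classmethod
    def log(cls, info):
        cls.logged.append(info)


def wrap_with_wandb(wandb, original_callback):
    # Mirrors the diff: forward every report to wandb.log, then delegate
    # to the wrapped callback if one was supplied.
    class WandBCallback(TrainingCallback):
        def on_train_loss_report(self, train_info: dict):
            wandb.log(train_info)
            if original_callback:
                original_callback.on_train_loss_report(train_info)

        def on_val_loss_report(self, val_info: dict):
            wandb.log(val_info)
            if original_callback:
                original_callback.on_val_loss_report(val_info)

    return WandBCallback()


cb = wrap_with_wandb(StubWandB, original_callback=None)
cb.on_train_loss_report({"iteration": 10, "train_loss": 2.31})  # hypothetical keys
assert StubWandB.logged == [{"iteration": 10, "train_loss": 2.31}]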