From a928bba3752f6e7be969b302683bd0be4247a6cd Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Tue, 28 Jan 2025 14:50:38 -0500
Subject: [PATCH] Pass down TrainingArgs instance to iterate_batches function
 and TrainingCallback methods

Addresses #1224
---
 llms/mlx_lm/tuner/trainer.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 63ca58bb..55719bb9 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
     return ce, ntoks
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
@@ -167,11 +169,13 @@ def evaluate(
 
 
 class TrainingCallback:
 
-    def on_train_loss_report(self, train_info: dict):
+    def on_train_loss_report(
+        self, train_info: dict, args: Optional[TrainingArgs] = None
+    ):
         """Called to report training loss at specified intervals."""
         pass
 
-    def on_val_loss_report(self, val_info: dict):
+    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
         """Called to report validation loss at specified intervals or the beginning."""
         pass
 
@@ -227,6 +231,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            args=args,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -258,7 +263,7 @@ def train(
                 "val_loss": val_loss,
                 "val_time": val_time,
             }
-            training_callback.on_val_loss_report(val_info)
+            training_callback.on_val_loss_report(val_info, args=args)
 
             start = time.perf_counter()
 
@@ -301,7 +306,7 @@ def train(
                 "trained_tokens": trained_tokens,
                 "peak_memory": peak_mem,
             }
-            training_callback.on_train_loss_report(train_info)
+            training_callback.on_train_loss_report(train_info, args=args)
 
             losses = 0
             n_tokens = 0
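
For reviewers: a minimal sketch of how a downstream callback could consume
the new `args` parameter. The `PrintCallback` name and its logging logic are
illustrative only, not part of this patch, and it assumes `train_info` carries
"iteration"/"train_loss" keys and `val_info` carries "iteration"/"val_loss"
keys, as built elsewhere in trainer.py.

    from typing import Optional

    from mlx_lm.tuner.trainer import TrainingArgs, TrainingCallback


    class PrintCallback(TrainingCallback):
        """Hypothetical callback that reads hyperparameters from TrainingArgs."""

        def on_train_loss_report(
            self, train_info: dict, args: Optional[TrainingArgs] = None
        ):
            # args may be None when invoked by code that predates this patch.
            total = args.iters if args is not None else "?"
            print(
                f"train {train_info['iteration']}/{total}: "
                f"loss={train_info['train_loss']:.3f}"
            )

        def on_val_loss_report(
            self, val_info: dict, args: Optional[TrainingArgs] = None
        ):
            print(f"val {val_info['iteration']}: loss={val_info['val_loss']:.3f}")

Because `args` defaults to None in both the callback methods and
iterate_batches, existing callbacks and callers keep working unchanged.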