diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index e6571ee5..b2c5ba69 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -99,6 +99,7 @@ def evaluate( num_batches, max_seq_length=2048, loss: callable = default_loss, + iterate_batches: callable = iterate_batches ): all_losses = [] ntokens = 0 @@ -126,6 +127,7 @@ def train( val_dataset, args: TrainingArgs = TrainingArgs(), loss: callable = default_loss, + iterate_batches: callable = iterate_batches ): # Create checkpoints directory if it does not exist if not os.path.exists("checkpoints"): @@ -186,6 +188,7 @@ def train( batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, + iterate_batches=iterate_batches ) print( f"Iter {it + 1}: "