From e446598f62b4e89e8d31b50080db5e071fd739ef Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Tue, 13 Feb 2024 10:03:25 -0500
Subject: [PATCH] Passing parameterized loss and batching to trainer (#391)

---
 llms/mlx_lm/tuner/trainer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index e6571ee5..b2c5ba69 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -99,6 +99,7 @@ def evaluate(
     num_batches,
     max_seq_length=2048,
     loss: callable = default_loss,
+    iterate_batches: callable = iterate_batches
 ):
     all_losses = []
     ntokens = 0
@@ -126,6 +127,7 @@ def train(
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
+    iterate_batches: callable = iterate_batches
 ):
     # Create checkpoints directory if it does not exist
     if not os.path.exists("checkpoints"):
@@ -186,6 +188,7 @@ def train(
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
+                iterate_batches=iterate_batches
             )
             print(
                 f"Iter {it + 1}: "
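
Usage note (illustrative, not part of the patch): the new keyword lets callers inject a custom batching strategy, alongside a custom loss, without forking the trainer. The sketch below is a minimal example under stated assumptions: it takes the default iterate_batches contract to be a generator with the signature iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) yielding (inputs, targets, lengths) tuples, matching what default_loss(model, inputs, targets, lengths) consumes. The helper name length_sorted_batches, its length-sorting heuristic, and the leading train() parameters (model, tokenizer, optimizer, train_dataset) are assumptions for illustration only.

    import numpy as np
    import mlx.core as mx

    # Assumed import path for the patched trainer module.
    from mlx_lm.tuner.trainer import train, TrainingArgs, default_loss


    def length_sorted_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
        # Hypothetical strategy: group examples of similar tokenized length
        # so each batch carries less padding than random batching would.
        tokens = [tokenizer.encode(item)[:max_seq_length] for item in dataset]
        order = np.argsort([len(t) for t in tokens])
        while True:
            starts = np.arange(0, len(order) - batch_size + 1, batch_size)
            if train:
                # Visit batches in random order each epoch, keeping
                # similar-length examples grouped together.
                starts = np.random.permutation(starts)
            for s in starts:
                chunk = [tokens[i] for i in order[s : s + batch_size]]
                lengths = [len(t) for t in chunk]
                batch = np.zeros((batch_size, max(lengths)), np.int32)
                for row, t in enumerate(chunk):
                    batch[row, : len(t)] = t
                batch = mx.array(batch)
                # Shifted views give next-token inputs/targets; lengths feed
                # the loss mask, mirroring the assumed default output shape.
                yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
            if not train:
                break


    # The new keyword threads the custom iterator through train() and the
    # periodic evaluate() calls alike:
    # train(model, tokenizer, optimizer, train_dataset, val_dataset,
    #       args=TrainingArgs(), loss=default_loss,
    #       iterate_batches=length_sorted_batches)

Because evaluate() now accepts the same keyword and train() forwards it (third hunk above), validation uses the injected iterator too, so train and eval batching cannot silently diverge.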