diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 63ca58bb..bf84d066 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -140,8 +140,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = 0 - ntokens = 0 + all_losses = mx.array(0.0) + ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)