From d9924d08d15fbc145466f06489d106b219f12323 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 3 Feb 2025 09:55:24 -0800 Subject: [PATCH] Fix no validation in lora (#1241) --- llms/mlx_lm/tuner/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 63ca58bb..bf84d066 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -140,8 +140,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = 0 - ntokens = 0 + all_losses = mx.array(0.0) + ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)