diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index e2b55db3..f5957782 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -64,7 +64,7 @@ class TrainingArgs: def default_loss(model, inputs, targets, lengths): - logits, _ = model(inputs) + logits = model(inputs) logits = logits.astype(mx.float32) length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]