From e2ace6fb0f1411e28b82d0c07337a72a6233317e Mon Sep 17 00:00:00 2001
From: paNikitin <115797306+paNikitin@users.noreply.github.com>
Date: Mon, 24 Feb 2025 09:12:31 +0300
Subject: [PATCH] Update trainer.py

---
 llms/mlx_lm/tuner/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index cd8f513c..04933fe7 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -111,7 +111,9 @@ def cot_loss(
     data_token_id = tokenizer.encode(args.data_token)[0]
 
     reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
-    data_positions = mx.argmax(targets == data_token_id, axis=1)
+    # find the LAST occurrence of data_token_id using slicing (in case generated dataset has multiple occurrences of [DATA])
+    data_positions = mx.argmax(targets[:, ::-1] == data_token_id, axis=1)
+    data_positions = targets.shape[1] - 1 - data_positions
 
     seq_indices = mx.arange(targets.shape[1])[None, :]