diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index cd8f513c..04933fe7 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -111,7 +111,9 @@ def cot_loss( data_token_id = tokenizer.encode(args.data_token)[0] reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) - data_positions = mx.argmax(targets == data_token_id, axis=1) + # find the LAST occurrence of data_token_id using slicing (in case generated dataset has multiple occurrences of [DATA]) + data_positions = mx.argmax(targets[:, ::-1] == data_token_id, axis=1) + data_positions = targets.shape[1] - 1 - data_positions seq_indices = mx.arange(targets.shape[1])[None, :]