Update trainer.py

This commit is contained in:
paNikitin 2025-02-24 09:12:31 +03:00
parent 231f5e870e
commit e2ace6fb0f

View File

@ -111,7 +111,9 @@ def cot_loss(
data_token_id = tokenizer.encode(args.data_token)[0] data_token_id = tokenizer.encode(args.data_token)[0]
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
data_positions = mx.argmax(targets == data_token_id, axis=1) # find the LAST occurrence of data_token_id using slicing (in case generated dataset has multiple occurrences of [DATA])
data_positions = mx.argmax(targets[:, ::-1] == data_token_id, axis=1)
data_positions = targets.shape[1] - 1 - data_positions
seq_indices = mx.arange(targets.shape[1])[None, :] seq_indices = mx.arange(targets.shape[1])[None, :]