mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
Update trainer.py
This commit is contained in:
parent
231f5e870e
commit
e2ace6fb0f
@ -111,7 +111,9 @@ def cot_loss(
|
|||||||
data_token_id = tokenizer.encode(args.data_token)[0]
|
data_token_id = tokenizer.encode(args.data_token)[0]
|
||||||
|
|
||||||
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
|
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
|
||||||
data_positions = mx.argmax(targets == data_token_id, axis=1)
|
# find the LAST occurrence of data_token_id using slicing (in case generated dataset has multiple occurrences of [DATA])
|
||||||
|
data_positions = mx.argmax(targets[:, ::-1] == data_token_id, axis=1)
|
||||||
|
data_positions = targets.shape[1] - 1 - data_positions
|
||||||
|
|
||||||
seq_indices = mx.arange(targets.shape[1])[None, :]
|
seq_indices = mx.arange(targets.shape[1])[None, :]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user