Update trainer.py

This commit is contained in:
paNikitin 2025-02-23 12:56:09 +03:00
parent a2b61afd05
commit 5b7581f41c

View File

@ -69,6 +69,14 @@ class TrainingArgs:
default=False,
metadata={"help": "Use CoT loss masking with positioning penalty"},
)
reasoning_token: str = field(
default="[REASONING]",
metadata={"help": "Reasoning token"},
)
data_token: str = field(
default="[DATA]",
metadata={"help": "Final answer token"},
)
def default_loss(model, batch, lengths):
@ -88,25 +96,19 @@ def default_loss(model, batch, lengths):
return ce, ntoks
@dataclass
class CotTrainingArgs:
cot: bool = False
reasoning_token: str = "[REASONING]"
data_token: str = "[DATA]"
def cot_loss(
model: nn.Module,
inputs: mx.array,
targets: mx.array,
lengths: int,
tokenizer: TokenizerWrapper,
args: TrainingArgs,
penalty: mx.float32 = 10.0,
) -> tuple[mx.array, mx.array]:
logits = model(inputs).astype(mx.float32)
reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
reasoning_token_id = tokenizer.encode(args.reasoning_token)[0]
data_token_id = tokenizer.encode(args.data_token)[0]
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
data_positions = mx.argmax(targets == data_token_id, axis=1)
@ -268,7 +270,7 @@ def train(
grad_checkpoint(model.layers[0])
if args.cot:
loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, args=args)
else:
loss = default_loss