diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 24e93f92..cd8f513c 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -69,6 +69,14 @@ class TrainingArgs: default=False, metadata={"help": "Use CoT loss masking with positioning penalty"}, ) + reasoning_token: str = field( + default="[REASONING]", + metadata={"help": "Reasoning token"}, + ) + data_token: str = field( + default="[DATA]", + metadata={"help": "Final answer token"}, + ) def default_loss(model, batch, lengths): @@ -88,25 +96,19 @@ def default_loss(model, batch, lengths): return ce, ntoks -@dataclass -class CotTrainingArgs: - cot: bool = False - reasoning_token: str = "[REASONING]" - data_token: str = "[DATA]" - - def cot_loss( model: nn.Module, inputs: mx.array, targets: mx.array, lengths: int, tokenizer: TokenizerWrapper, + args: TrainingArgs, penalty: mx.float32 = 10.0, ) -> tuple[mx.array, mx.array]: logits = model(inputs).astype(mx.float32) - reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0] - data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0] + reasoning_token_id = tokenizer.encode(args.reasoning_token)[0] + data_token_id = tokenizer.encode(args.data_token)[0] reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) data_positions = mx.argmax(targets == data_token_id, axis=1) @@ -268,7 +270,7 @@ def train( grad_checkpoint(model.layers[0]) if args.cot: - loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0) + loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, args=args) else: loss = default_loss