Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 06:00:19 +08:00)

Update trainer.py

This commit is contained in:
parent a2b61afd05
commit 5b7581f41c
@@ -69,6 +69,14 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use CoT loss masking with positioning penalty"},
     )
+    reasoning_token: str = field(
+        default="[REASONING]",
+        metadata={"help": "Reasoning token"},
+    )
+    data_token: str = field(
+        default="[DATA]",
+        metadata={"help": "Final answer token"},
+    )
 
 
 def default_loss(model, batch, lengths):
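For illustration only (not part of the diff), a minimal sketch of constructing the training configuration with the new fields; all other TrainingArgs fields are assumed to keep their defaults:

    args = TrainingArgs(
        cot=True,
        reasoning_token="[REASONING]",
        data_token="[DATA]",
    )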
@@ -88,25 +96,19 @@ def default_loss(model, batch, lengths):
     return ce, ntoks
 
 
-@dataclass
-class CotTrainingArgs:
-    cot: bool = False
-    reasoning_token: str = "[REASONING]"
-    data_token: str = "[DATA]"
-
-
 def cot_loss(
     model: nn.Module,
     inputs: mx.array,
     targets: mx.array,
     lengths: int,
     tokenizer: TokenizerWrapper,
+    args: TrainingArgs,
     penalty: mx.float32 = 10.0,
 ) -> tuple[mx.array, mx.array]:
     logits = model(inputs).astype(mx.float32)
 
-    reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
-    data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
+    reasoning_token_id = tokenizer.encode(args.reasoning_token)[0]
+    data_token_id = tokenizer.encode(args.data_token)[0]
 
     reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
     data_positions = mx.argmax(targets == data_token_id, axis=1)
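The hunk ends at the two position lookups, so the masking and penalty logic itself is not visible here. The following is a hypothetical continuation, not the commit's actual code: it assumes `lengths` is an mx.array of per-sample sequence lengths (as in default_loss), restricts the cross-entropy to the final-answer span starting at the [DATA] marker, and adds `penalty` for samples whose [REASONING] marker does not precede the [DATA] marker.

    # Hypothetical sketch only -- not taken from the commit.
    seq = mx.arange(targets.shape[1])

    # Length mask, as in default_loss (assumes lengths is an array).
    length_mask = seq[None, :] < lengths[:, None]

    # Restrict the loss to the final-answer span: tokens at or after [DATA].
    answer_mask = seq[None, :] >= data_positions[:, None]
    mask = length_mask & answer_mask

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    ce = ce.sum() / ntoks

    # Positioning penalty: charge `penalty` per sample whose [REASONING]
    # marker does not come before its [DATA] marker.
    misordered = (reasoning_positions >= data_positions).astype(mx.float32)
    ce = ce + penalty * misordered.mean()

    return ce, ntoks

Masking to the answer span is one plausible reading of "CoT loss masking"; the commit's actual masking policy is not visible in this hunk.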
@@ -268,7 +270,7 @@ def train(
         grad_checkpoint(model.layers[0])
 
     if args.cot:
-        loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
+        loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, args=args)
     else:
         loss = default_loss
 
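Since cot_loss now takes a required args parameter, the partial must bind it (args=args); that binding is what lets the loss read reasoning_token and data_token from the training configuration instead of the CotTrainingArgs class attributes used before.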