Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 01:12:24 +08:00

Update trainer.py

Commit 5b7581f41c (parent a2b61afd05)
@@ -69,6 +69,14 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use CoT loss masking with positioning penalty"},
     )
+    reasoning_token: str = field(
+        default="[REASONING]",
+        metadata={"help": "Reasoning token"},
+    )
+    data_token: str = field(
+        default="[DATA]",
+        metadata={"help": "Final answer token"},
+    )


 def default_loss(model, batch, lengths):
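The first hunk makes the two CoT marker tokens configurable fields on TrainingArgs rather than hard-coded strings. A minimal sketch of how a run might set them (illustrative only; the other TrainingArgs fields are elided):

    args = TrainingArgs(
        cot=True,                       # enable CoT loss masking
        reasoning_token="[REASONING]",  # marks the start of the reasoning span
        data_token="[DATA]",            # marks the start of the final answer
    )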
@@ -88,25 +96,19 @@ def default_loss(model, batch, lengths):
     return ce, ntoks


-@dataclass
-class CotTrainingArgs:
-    cot: bool = False
-    reasoning_token: str = "[REASONING]"
-    data_token: str = "[DATA]"
-
-
 def cot_loss(
     model: nn.Module,
     inputs: mx.array,
     targets: mx.array,
     lengths: int,
     tokenizer: TokenizerWrapper,
+    args: TrainingArgs,
     penalty: mx.float32 = 10.0,
 ) -> tuple[mx.array, mx.array]:
     logits = model(inputs).astype(mx.float32)

-    reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
-    data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
+    reasoning_token_id = tokenizer.encode(args.reasoning_token)[0]
+    data_token_id = tokenizer.encode(args.data_token)[0]

     reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
     data_positions = mx.argmax(targets == data_token_id, axis=1)
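The second hunk drops the redundant CotTrainingArgs dataclass and reads the tokens from the args instance passed into the loss. This also fixes a real bug: CotTrainingArgs.reasoning_token was a class-level default, so user-supplied tokens were silently ignored. Below is a minimal, self-contained sketch of the masking idea, assuming MLX's mx.argmax and nn.losses.cross_entropy, that lengths is a per-row 1-D array, and that each special token encodes to a single id; the mask construction and the form of the positioning penalty are illustrative, not the commit's exact implementation:

    import mlx.core as mx
    import mlx.nn as nn

    def cot_loss_sketch(model, inputs, targets, lengths, tokenizer, args, penalty=10.0):
        logits = model(inputs).astype(mx.float32)

        # Assumes each special token encodes to exactly one id.
        reasoning_token_id = tokenizer.encode(args.reasoning_token)[0]
        data_token_id = tokenizer.encode(args.data_token)[0]

        # mx.argmax on a boolean array returns the first True per row,
        # i.e. the first occurrence of each marker token.
        reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
        data_positions = mx.argmax(targets == data_token_id, axis=1)

        # Score only non-padding tokens at or after the answer marker.
        steps = mx.arange(targets.shape[1])[None, :]
        mask = (steps < lengths[:, None]) & (steps >= data_positions[:, None])

        ce = nn.losses.cross_entropy(logits, targets, reduction="none") * mask
        ntoks = mask.sum()

        # Assumed "positioning penalty": punish rows where the reasoning
        # marker does not precede the answer marker.
        misordered = (reasoning_positions >= data_positions).astype(mx.float32)
        loss = ce.sum() / ntoks + penalty * misordered.mean()
        return loss, ntoks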
@@ -268,7 +270,7 @@ def train(
         grad_checkpoint(model.layers[0])

     if args.cot:
-        loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
+        loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, args=args)
     else:
         loss = default_loss

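The final hunk updates the functools.partial call in lockstep with cot_loss's new args parameter. Pre-binding the CoT-specific keyword arguments lets both branches present the same positional interface to the training loop; a hedged sketch of how the loop would invoke it (the exact batch unpacking is an assumption):

    from functools import partial

    loss_fn = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, args=args)
    # The training loop can then treat both branches uniformly:
    lvalue, ntoks = loss_fn(model, inputs, targets, lengths)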