Possible bug (default_loss)

This commit is contained in:
JosefAlbers
2024-05-10 22:48:59 +09:00
committed by GitHub
parent a9192f81b1
commit 70a55ace18

View File

@@ -64,7 +64,7 @@ class TrainingArgs:
def default_loss(model, inputs, targets, lengths):
logits, _ = model(inputs)
logits = model(inputs)
logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]