Possible bug (default_loss)

This commit is contained in:
JosefAlbers 2024-05-10 22:48:59 +09:00 committed by GitHub
parent a9192f81b1
commit 70a55ace18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -64,7 +64,7 @@ class TrainingArgs:
def default_loss(model, inputs, targets, lengths): def default_loss(model, inputs, targets, lengths):
logits, _ = model(inputs) logits = model(inputs)
logits = logits.astype(mx.float32) logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]