mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Possible bug (default_loss)
This commit is contained in:
parent
a9192f81b1
commit
70a55ace18
@ -64,7 +64,7 @@ class TrainingArgs:
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits, _ = model(inputs)
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
Loading…
Reference in New Issue
Block a user