mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Possible bug (default_loss)
This commit is contained in:
parent
a9192f81b1
commit
70a55ace18
@ -64,7 +64,7 @@ class TrainingArgs:
|
|||||||
|
|
||||||
|
|
||||||
def default_loss(model, inputs, targets, lengths):
|
def default_loss(model, inputs, targets, lengths):
|
||||||
logits, _ = model(inputs)
|
logits = model(inputs)
|
||||||
logits = logits.astype(mx.float32)
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||||
|
Loading…
Reference in New Issue
Block a user