mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-24 15:58:11 +08:00
Possible bug (default_loss)
This commit is contained in:
@@ -64,7 +64,7 @@ class TrainingArgs:
|
|||||||
|
|
||||||
|
|
||||||
def default_loss(model, inputs, targets, lengths):
|
def default_loss(model, inputs, targets, lengths):
|
||||||
logits, _ = model(inputs)
|
logits = model(inputs)
|
||||||
logits = logits.astype(mx.float32)
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||||
|
Reference in New Issue
Block a user