mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-23 23:28:12 +08:00
Possible bug (default_loss)
This commit is contained in:
@@ -64,7 +64,7 @@ class TrainingArgs:
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits, _ = model(inputs)
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
Reference in New Issue
Block a user