mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
fix lora for openelm (#773)
This commit is contained in:
@@ -64,7 +64,7 @@ class TrainingArgs:
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits, _ = model(inputs)
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
Reference in New Issue
Block a user