Fix LoRA for OpenELM (#773)

This commit is contained in:
Awni Hannun
2024-05-10 09:51:41 -07:00
committed by GitHub
parent fad9598372
commit 6f0a69e682
2 changed files with 4 additions and 2 deletions

View File

@@ -64,7 +64,7 @@ class TrainingArgs:
 def default_loss(model, inputs, targets, lengths):
-    logits, _ = model(inputs)
+    logits = model(inputs)
     logits = logits.astype(mx.float32)
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]