Some improvements to speedup alignment computation in MLX Whisper (#1259)

* some improvements to speedup alignment computation in MLX Whisper

* fix alignment
This commit is contained in:
Awni Hannun
2025-02-08 15:47:00 -08:00
committed by GitHub
parent 1503bd4f55
commit f58c7de901
2 changed files with 5 additions and 6 deletions

View File

@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
w = mx.softmax(qk, axis=-1, precise=True)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out, qk.astype(mx.float32)
return out, qk
class ResidualAttentionBlock(nn.Module):