Whisper improvements (#1080)

* use safetensors in whisper

* speed up decoder

* version
This commit is contained in:
Awni Hannun
2024-11-01 10:52:28 -07:00
committed by GitHub
parent 85ffd2c96a
commit 8160e0c4e5
6 changed files with 85 additions and 64 deletions

View File

@@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.astype(mx.float32)
w = mx.softmax(qk, axis=-1).astype(q.dtype)
w = mx.softmax(qk, axis=-1, precise=True)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out, qk
return out, qk.astype(mx.float32)
class ResidualAttentionBlock(nn.Module):