mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
Whisper improvements (#1080)
* use safetensors in whisper * speed up decoder * version
This commit is contained in:
@@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.astype(mx.float32)
|
||||
|
||||
w = mx.softmax(qk, axis=-1).astype(q.dtype)
|
||||
w = mx.softmax(qk, axis=-1, precise=True)
|
||||
out = (w @ v).transpose(0, 2, 1, 3)
|
||||
out = out.reshape(n_batch, n_ctx, n_state)
|
||||
return out, qk
|
||||
return out, qk.astype(mx.float32)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
Reference in New Issue
Block a user