Some improvements to speedup alignment computation in MLX Whisper (#1259)

* some improvements to speedup alignment computation in MLX Whisper

* fix alignment
This commit is contained in:
Awni Hannun 2025-02-08 15:47:00 -08:00 committed by GitHub
parent 1503bd4f55
commit f58c7de901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 6 deletions

View File

@ -134,9 +134,7 @@ def find_alignment(
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
sampled_logits.dtype
)
token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
text_token_probs = mx.take_along_axis(
token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
@ -144,10 +142,11 @@ def find_alignment(
# heads * tokens * frames
weights = mx.stack(
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
)
weights = weights[:, :, : num_frames // 2]
weights = mx.softmax(weights * qk_scale, axis=-1)
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
weights = weights.astype(mx.float32)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std

View File

@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
w = mx.softmax(qk, axis=-1, precise=True)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out, qk.astype(mx.float32)
return out, qk
class ResidualAttentionBlock(nn.Module):