From f58c7de9017b54b044703f88787e6c679db9ec7e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 8 Feb 2025 15:47:00 -0800
Subject: [PATCH] Some improvements to speedup alignment computation in MLX
 Whisper (#1259)

* some improvements to speedup alignment computation in MLX Whisper

* fix alignment
---
 whisper/mlx_whisper/timing.py  | 9 ++++-----
 whisper/mlx_whisper/whisper.py | 2 +-
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py
index 04915deb..07b81186 100644
--- a/whisper/mlx_whisper/timing.py
+++ b/whisper/mlx_whisper/timing.py
@@ -134,9 +134,7 @@ def find_alignment(
     logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
     # consider only the logits associated with predicting text
     sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
-    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
-        sampled_logits.dtype
-    )
+    token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
     text_token_probs = mx.take_along_axis(
         token_probs, mx.array(text_tokens)[:, None], axis=1
     ).squeeze(1)
@@ -144,10 +142,11 @@
 
     # heads * tokens * frames
     weights = mx.stack(
-        [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+        [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
     )
     weights = weights[:, :, : num_frames // 2]
-    weights = mx.softmax(weights * qk_scale, axis=-1)
+    weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
+    weights = weights.astype(mx.float32)
     mean = mx.mean(weights, axis=-2, keepdims=True)
     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std
diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py
index 1c2b390e..5c85195c 100644
--- a/whisper/mlx_whisper/whisper.py
+++ b/whisper/mlx_whisper/whisper.py
@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
         w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk.astype(mx.float32)
+        return out, qk


 class ResidualAttentionBlock(nn.Module):
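
The patch boils down to two MLX idioms. First, mx.softmax(..., precise=True) accumulates the softmax in float32 internally while keeping the input dtype, replacing the explicit upcast-to-float32-and-downcast dance (the same reason whisper.py can stop returning qk.astype(mx.float32) and find_alignment can upcast only the small attention-weight tensor after its softmax). Second, calling .tolist() once on model.alignment_heads replaces the per-element .item() calls, each of which forces a separate evaluation and host sync inside the list comprehension. Below is a minimal standalone sketch of both idioms, assuming only that mlx is installed; the shapes and head indices are made up for illustration and are not Whisper's real ones.

import mlx.core as mx

x = mx.random.normal((4, 8)).astype(mx.float16)

# Before: upcast to float32, softmax, downcast -- two extra casts/copies.
slow = mx.softmax(x.astype(mx.float32), axis=-1).astype(x.dtype)

# After: precise=True computes the softmax in float32 internally but
# returns the input dtype, so no explicit casts are needed.
fast = mx.softmax(x, axis=-1, precise=True)

# Before: each .item() call forces a separate evaluation/host sync.
heads = mx.array([[2, 3], [4, 1]])  # hypothetical (layer, head) pairs
slow_idx = [(_l.item(), _h.item()) for _l, _h in heads]

# After: one .tolist() call pulls all indices to the host in a single sync.
fast_idx = [(_l, _h) for _l, _h in heads.tolist()]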