From 9414429309bec6cda50125aa13ae9964f32f2300 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 7 Feb 2025 09:07:24 -0800 Subject: [PATCH] some improvements to speedup alignment computation in MLX Whisper --- whisper/mlx_whisper/timing.py | 7 +++---- whisper/mlx_whisper/whisper.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py index 04915deb..5ec0a365 100644 --- a/whisper/mlx_whisper/timing.py +++ b/whisper/mlx_whisper/timing.py @@ -134,9 +134,7 @@ def find_alignment( logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :]) # consider only the logits associated with predicting text sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot] - token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype( - sampled_logits.dtype - ) + token_probs = mx.softmax(sampled_logits, precise=True, axis=-1) text_token_probs = mx.take_along_axis( token_probs, mx.array(text_tokens)[:, None], axis=1 ).squeeze(1) @@ -147,7 +145,8 @@ def find_alignment( [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads] ) weights = weights[:, :, : num_frames // 2] - weights = mx.softmax(weights * qk_scale, axis=-1) + weights = mx.softmax(weights * qk_scale, axis=-1, precise=True) + weights = weights.astype(mx.float32) mean = mx.mean(weights, axis=-2, keepdims=True) std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt() weights = (weights - mean) / std diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py index 1c2b390e..badb05e7 100644 --- a/whisper/mlx_whisper/whisper.py +++ b/whisper/mlx_whisper/whisper.py @@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module): w = mx.softmax(qk, axis=-1, precise=True) out = (w @ v).transpose(0, 2, 1, 3) out = out.reshape(n_batch, n_ctx, n_state) - return out, qk.astype(mx.float32) + return out, qk class ResidualAttentionBlock(nn.Module): @@ -228,13 +228,13 @@ class Whisper(nn.Module): def set_alignment_heads(self, dump: Union[bytes, np.ndarray]): if isinstance(dump, np.ndarray): - self.alignment_heads = mx.array(dump) + self.alignment_heads = dump elif isinstance(dump, bytes): array = np.frombuffer( gzip.decompress(base64.b85decode(dump)), dtype=bool ).copy() mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head) - self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T) + self.alignment_heads = np.asarray(mask.nonzero()).T else: raise ValueError( f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"