some improvements to speed up alignment computation in MLX Whisper

Awni Hannun 2025-02-07 09:07:24 -08:00
parent e2e5478da5
commit 9414429309
2 changed files with 6 additions and 7 deletions
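The hunks below replace the explicit float32 round-trips around mx.softmax with the precise=True option and keep the alignment-head indices as a plain NumPy array rather than an MLX array. A minimal sketch of the softmax pattern, using an illustrative input tensor (the variable names here are not from the commit):

    import mlx.core as mx

    x = mx.random.normal((4, 8)).astype(mx.float16)

    # old pattern: upcast to float32, softmax, cast back to the input dtype
    probs_old = mx.softmax(x.astype(mx.float32), axis=-1).astype(x.dtype)

    # new pattern: one softmax call that uses higher internal precision
    probs_new = mx.softmax(x, axis=-1, precise=True)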


@@ -134,9 +134,7 @@ def find_alignment(
     logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
     # consider only the logits associated with predicting text
     sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
-    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
-        sampled_logits.dtype
-    )
+    token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
     text_token_probs = mx.take_along_axis(
         token_probs, mx.array(text_tokens)[:, None], axis=1
     ).squeeze(1)
@@ -147,7 +145,8 @@ def find_alignment(
         [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
     )
     weights = weights[:, :, : num_frames // 2]
-    weights = mx.softmax(weights * qk_scale, axis=-1)
+    weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
+    weights = weights.astype(mx.float32)
     mean = mx.mean(weights, axis=-2, keepdims=True)
     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std
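The new weights.astype(mx.float32) line lands right before the mean/std normalization, presumably so that step runs in full precision. A small self-contained sketch of that standardization on a stand-in weights array (shapes and values are illustrative, not from the commit):

    import mlx.core as mx

    # stand-in attention weights: (heads, text_tokens, audio_frames)
    weights = mx.random.uniform(shape=(6, 10, 50)).astype(mx.float16)
    weights = weights.astype(mx.float32)

    # standardize over the text-token axis for each head/frame column
    mean = mx.mean(weights, axis=-2, keepdims=True)
    std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
    weights = (weights - mean) / std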


@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
         w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk.astype(mx.float32)
+        return out, qk
 
 
 class ResidualAttentionBlock(nn.Module):
@@ -228,13 +228,13 @@ class Whisper(nn.Module):
 
     def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
         if isinstance(dump, np.ndarray):
-            self.alignment_heads = mx.array(dump)
+            self.alignment_heads = dump
         elif isinstance(dump, bytes):
             array = np.frombuffer(
                 gzip.decompress(base64.b85decode(dump)), dtype=bool
             ).copy()
             mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
-            self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
+            self.alignment_heads = np.asarray(mask.nonzero()).T
         else:
             raise ValueError(
                 f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"