diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py
index 04915deb..07b81186 100644
--- a/whisper/mlx_whisper/timing.py
+++ b/whisper/mlx_whisper/timing.py
@@ -134,9 +134,7 @@ def find_alignment(
     logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
     # consider only the logits associated with predicting text
     sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
-    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
-        sampled_logits.dtype
-    )
+    token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
     text_token_probs = mx.take_along_axis(
         token_probs, mx.array(text_tokens)[:, None], axis=1
     ).squeeze(1)
@@ -144,10 +142,11 @@
 
     # heads * tokens * frames
     weights = mx.stack(
-        [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+        [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
     )
     weights = weights[:, :, : num_frames // 2]
-    weights = mx.softmax(weights * qk_scale, axis=-1)
+    weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
+    weights = weights.astype(mx.float32)
     mean = mx.mean(weights, axis=-2, keepdims=True)
     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std
diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py
index 1c2b390e..5c85195c 100644
--- a/whisper/mlx_whisper/whisper.py
+++ b/whisper/mlx_whisper/whisper.py
@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
         w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk.astype(mx.float32)
+        return out, qk


 class ResidualAttentionBlock(nn.Module):
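For context, a minimal sketch (not part of the diff) of the pattern the timing.py hunk replaces: manually casting to float32 around the softmax versus passing `precise=True` and letting the op accumulate in higher precision internally. This assumes MLX's `mx.softmax` accepts the `precise` keyword, as both hunks use it; the input values and tolerance below are illustrative only.

    # sketch: old cast-based softmax vs. the precise=True form used in the diff
    import mlx.core as mx

    x = mx.random.normal((4, 8)).astype(mx.float16)

    # old pattern: upcast to float32, softmax, cast back to the input dtype
    old = mx.softmax(x.astype(mx.float32), axis=-1).astype(x.dtype)

    # new pattern: softmax does the high-precision accumulation itself
    new = mx.softmax(x, axis=-1, precise=True)

    print(mx.allclose(old, new, atol=1e-3))  # expected to agree closely for well-scaled inputs

The whisper.py hunk follows the same idea from the other side: since downstream consumers of `qk` (here, `find_alignment`) now handle precision at the softmax call, the cross-attention scores no longer need to be returned pre-cast to float32.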