From f3960c2b4bddfc762a7c5d543c4bbf2f5610fefb Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 7 Feb 2025 09:30:20 -0800 Subject: [PATCH] fix alignment --- whisper/mlx_whisper/timing.py | 2 +- whisper/mlx_whisper/whisper.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py index 5ec0a365..07b81186 100644 --- a/whisper/mlx_whisper/timing.py +++ b/whisper/mlx_whisper/timing.py @@ -142,7 +142,7 @@ def find_alignment( # heads * tokens * frames weights = mx.stack( - [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads] + [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()] ) weights = weights[:, :, : num_frames // 2] weights = mx.softmax(weights * qk_scale, axis=-1, precise=True) diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py index badb05e7..5c85195c 100644 --- a/whisper/mlx_whisper/whisper.py +++ b/whisper/mlx_whisper/whisper.py @@ -228,13 +228,13 @@ class Whisper(nn.Module): def set_alignment_heads(self, dump: Union[bytes, np.ndarray]): if isinstance(dump, np.ndarray): - self.alignment_heads = dump + self.alignment_heads = mx.array(dump) elif isinstance(dump, bytes): array = np.frombuffer( gzip.decompress(base64.b85decode(dump)), dtype=bool ).copy() mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head) - self.alignment_heads = np.asarray(mask.nonzero()).T + self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T) else: raise ValueError( f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"