From bf9926489ea5fbdb8a0ce5a224293c95631d588d Mon Sep 17 00:00:00 2001 From: bofeng huang Date: Sun, 7 Jan 2024 19:01:29 +0100 Subject: [PATCH] [Whisper] Add word timestamps and confidence scores (#201) * Add word timestamps and confidence scores * Create a separate forward_with_cross_qk function * Move multiple ops from np to mlx, clean comments * Save alignment_heads * Cast qk to fp32 * Add test for word-level timestamps and confidence scores * format + readme * nit --------- Co-authored-by: Awni Hannun --- whisper/README.md | 21 +++- whisper/convert.py | 4 + whisper/test.py | 131 +++++++++++++++++++++++++ whisper/whisper/decoding.py | 2 +- whisper/whisper/timing.py | 123 +++++++----------------- whisper/whisper/transcribe.py | 174 ++++++++++++++++++++++++++++++++-- whisper/whisper/whisper.py | 54 ++++++++--- 7 files changed, 398 insertions(+), 111 deletions(-) diff --git a/whisper/README.md b/whisper/README.md index 50fc0764..071b3fc4 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -2,7 +2,7 @@ Speech recognition with Whisper in MLX. Whisper is a set of open source speech recognition models from OpenAI, ranging from 39 million to 1.5 billion -parameters[^1]. +parameters.[^1] ### Setup @@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/): brew install ffmpeg ``` -Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use: +Next, download the Whisper PyTorch checkpoint and convert the weights to the +MLX format. For example, to convert the `tiny` model use: ``` python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny @@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there. Transcribe audio with: -``` +```python import whisper text = whisper.transcribe(speech_file)["text"] ``` +The `transcribe` function also supports word-level timestamps. You can generate +these with: + +```python +output = whisper.transcribe(speech_file, word_timestamps=True) +print(output["segments"][0]["words"]) +``` + +To see more transcription options use: + +``` +>>> help(whisper.transcribe) +``` + [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details. 
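For reference, the shape of the word-level output documented in the README section above is pinned down by the test expectations added in `whisper/test.py` further down in this patch. A minimal consumption sketch, assuming `speech_file` as in the README snippet:

```python
import whisper

output = whisper.transcribe(speech_file, word_timestamps=True)
for word in output["segments"][0]["words"]:
    # Each entry carries the word text, its start/end times in seconds, and a
    # confidence score, e.g. {"word": " Then", "start": 0.0, "end": 0.94, "probability": 0.86}
    print(f'{word["word"]}\t{word["start"]:.2f}-{word["end"]:.2f}\t{word["probability"]:.2f}')
```
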
diff --git a/whisper/convert.py b/whisper/convert.py index 15a12855..2e4ebce5 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -199,6 +199,10 @@ def torch_to_mlx( mlx_model = Whisper(torch_model.dims, dtype) params = tree_map(lambda p: p.astype(dtype), params) mlx_model.update(params) + + if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None: + mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy()) + return mlx_model diff --git a/whisper/test.py b/whisper/test.py index 48a09152..36ad6f1c 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase): check_segment(result["segments"][5], expected_5) check_segment(result["segments"][73], expected_73) + def test_transcribe_word_level_timestamps_confidence_scores(self): + result = whisper.transcribe( + # TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False + TEST_AUDIO, + model_path=MLX_FP16_MODEL_PATH, + word_timestamps=True, + ) + + # result predicted with openai-whisper + expected_0 = [ + { + "word": " Then", + "start": 0.0, + "end": 0.94, + "probability": 0.855542778968811, + }, + { + "word": " the", + "start": 0.94, + "end": 1.12, + "probability": 0.6500106453895569, + }, + { + "word": " good", + "start": 1.12, + "end": 1.32, + "probability": 0.5503873825073242, + }, + { + "word": " soul", + "start": 1.32, + "end": 1.56, + "probability": 0.46757155656814575, + }, + { + "word": " openly", + "start": 1.56, + "end": 2.0, + "probability": 0.9840946793556213, + }, + { + "word": " sorted", + "start": 2.0, + "end": 2.38, + "probability": 0.24167272448539734, + }, + { + "word": " the", + "start": 2.38, + "end": 2.58, + "probability": 0.9875414967536926, + }, + { + "word": " boat", + "start": 2.58, + "end": 2.8, + "probability": 0.5856029391288757, + }, + { + "word": " and", + "start": 2.8, + "end": 2.98, + "probability": 0.913351833820343, + }, + { + "word": " she", + "start": 2.98, + "end": 3.1, + "probability": 0.9913808703422546, + }, + { + "word": " had", + "start": 3.1, + "end": 3.32, + "probability": 0.9952940344810486, + }, + { + "word": " buoyed", + "start": 3.32, + "end": 3.58, + "probability": 0.6411589980125427, + }, + { + "word": " so", + "start": 3.58, + "end": 3.8, + "probability": 0.9682658314704895, + }, + { + "word": " long", + "start": 3.8, + "end": 4.06, + "probability": 0.9953522682189941, + }, + { + "word": " in", + "start": 4.06, + "end": 4.26, + "probability": 0.6745936870574951, + }, + { + "word": " secret", + "start": 4.26, + "end": 4.56, + "probability": 0.9905064702033997, + }, + { + "word": " and", + "start": 4.56, + "end": 4.9, + "probability": 0.856008768081665, + }, + { + "word": " bravely", + "start": 4.9, + "end": 5.28, + "probability": 0.8477402329444885, + }, + ] + + def check_words(words, expected_words): + for word, expected_word in zip(words, expected_words): + for k, v in expected_word.items(): + if isinstance(v, float): + self.assertAlmostEqual(word[k], v, places=1) + else: + self.assertEqual(word[k], v) + + # Randomly check a couple of segments + check_words(result["segments"][0]["words"], expected_0) + class TestAudio(unittest.TestCase): def test_load(self): diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index b7786f92..c2105972 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -141,7 +141,7 @@ class Inference: # only need to use the last token except in the first forward pass tokens = tokens[:, -1:] - logits, self.kv_cache = 
self.model.decoder( + logits, self.kv_cache, _ = self.model.decoder( tokens, audio_features, kv_cache=self.kv_cache ) return logits.astype(mx.float32) diff --git a/whisper/whisper/timing.py b/whisper/whisper/timing.py index 5aee6043..13c36315 100644 --- a/whisper/whisper/timing.py +++ b/whisper/whisper/timing.py @@ -1,15 +1,13 @@ # Copyright © 2023 Apple Inc. import itertools -import subprocess -import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, List +import mlx.core as mx import numba import numpy as np -import torch -import torch.nn.functional as F +from scipy import signal from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND from .tokenizer import Tokenizer @@ -18,7 +16,7 @@ if TYPE_CHECKING: from .model import Whisper -def median_filter(x: torch.Tensor, filter_width: int): +def median_filter(x: np.ndarray, filter_width: int): """Apply a median filter of width `filter_width` along the last dimension of `x`""" pad_width = filter_width // 2 if x.shape[-1] <= pad_width: @@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int): filter_width > 0 and filter_width % 2 == 1 ), "`filter_width` should be an odd number" - result = None - x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") - if x.is_cuda: - try: - from .triton_ops import median_filter_cuda + x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect") - result = median_filter_cuda(x, filter_width) - except (RuntimeError, subprocess.CalledProcessError): - warnings.warn( - "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " - "falling back to a slower median kernel implementation..." - ) - - if result is None: - # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) - result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] + # todo: more efficient version in mlx + result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[ + ..., pad_width:-pad_width + ] if ndim <= 2: result = result[0, 0] @@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray): return backtrace(trace) -def dtw_cuda(x, BLOCK_SIZE=1024): - from .triton_ops import dtw_kernel - - M, N = x.shape - assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" - - x_skew = ( - F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) - ) - x_skew = x_skew.T.contiguous() - cost = torch.ones(N + M + 2, M + 2) * np.inf - cost[0, 0] = 0 - cost = cost.cuda() - trace = torch.zeros_like(cost, dtype=torch.int32) - - dtw_kernel[(1,)]( - cost, - trace, - x_skew, - x_skew.stride(0), - cost.stride(0), - trace.stride(0), - N, - M, - BLOCK_SIZE=BLOCK_SIZE, - ) - - trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ - :, : N + 1 - ] - return backtrace(trace.cpu().numpy()) - - -def dtw(x: torch.Tensor) -> np.ndarray: - if x.is_cuda: - try: - return dtw_cuda(x) - except (RuntimeError, subprocess.CalledProcessError): - warnings.warn( - "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " - "falling back to a slower DTW implementation..." 
- ) - - return dtw_cpu(x.double().cpu().numpy()) +def dtw(x: np.ndarray) -> np.ndarray: + # todo: more efficient version in mlx + return dtw_cpu(x) @dataclass @@ -166,7 +113,7 @@ def find_alignment( model: "Whisper", tokenizer: Tokenizer, text_tokens: List[int], - mel: torch.Tensor, + mel: mx.array, num_frames: int, *, medfilt_width: int = 7, @@ -175,41 +122,36 @@ def find_alignment( if len(text_tokens) == 0: return [] - tokens = torch.tensor( + tokens = mx.array( [ *tokenizer.sot_sequence, tokenizer.no_timestamps, *text_tokens, tokenizer.eot, ] - ).to(model.device) + ) - # install hooks on the cross attention layers to retrieve the attention weights - QKs = [None] * model.dims.n_text_layer - hooks = [ - block.cross_attn.register_forward_hook( - lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) - ) - for i, block in enumerate(model.decoder.blocks) - ] - - with torch.no_grad(): - logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] - sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] - token_probs = sampled_logits.softmax(dim=-1) - text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] - text_token_probs = text_token_probs.tolist() - - for hook in hooks: - hook.remove() + logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :]) + # consider only the logits associated with predicting text + sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot] + token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype( + sampled_logits.dtype + ) + text_token_probs = mx.take_along_axis( + token_probs, mx.array(text_tokens)[:, None], axis=1 + ).squeeze(1) + text_token_probs = np.array(text_token_probs) # heads * tokens * frames - weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) + weights = mx.stack( + [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads] + ) weights = weights[:, :, : num_frames // 2] - weights = (weights * qk_scale).softmax(dim=-1) - std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) + weights = mx.softmax(weights * qk_scale, axis=-1) + mean = mx.mean(weights, axis=-2, keepdims=True) + std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt() weights = (weights - mean) / std - weights = median_filter(weights, medfilt_width) + weights = median_filter(np.array(weights), medfilt_width) matrix = weights.mean(axis=0) matrix = matrix[len(tokenizer.sot_sequence) : -1] @@ -281,7 +223,7 @@ def add_word_timestamps( segments: List[dict], model: "Whisper", tokenizer: Tokenizer, - mel: torch.Tensor, + mel: mx.array, num_frames: int, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", @@ -301,6 +243,7 @@ def add_word_timestamps( word_durations = np.array([t.end - t.start for t in alignment]) word_durations = word_durations[word_durations.nonzero()] median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + median_duration = min(0.7, float(median_duration)) max_duration = median_duration * 2 # hack: truncate long words at sentence boundaries. diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index 330aef42..704fd36c 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -1,7 +1,8 @@ # Copyright © 2023 Apple Inc. 
import sys -from typing import Optional, Tuple, Union +import warnings +from typing import List, Optional, Tuple, Union import mlx.core as mx import numpy as np @@ -18,7 +19,8 @@ from .audio import ( ) from .decoding import DecodingOptions, DecodingResult from .load_models import load_model -from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer +from .timing import add_word_timestamps +from .tokenizer import LANGUAGES, get_tokenizer def _format_timestamp(seconds: float): @@ -38,6 +40,13 @@ def _format_timestamp(seconds: float): return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}" +def _get_end(segments: List[dict]) -> Optional[float]: + return next( + (w["end"] for s in reversed(segments) for w in reversed(s["words"])), + segments[-1]["end"] if segments else None, + ) + + class ModelHolder: model = None model_path = None @@ -61,8 +70,11 @@ def transcribe( no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, + word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, **decode_options, ): """ @@ -99,6 +111,16 @@ def transcribe( disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + word_timestamps: bool + Extract word-level timestamps using the cross-attention pattern and dynamic time warping, + and include the timestamps for each word in each segment. + + prepend_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the next word + + append_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the previous word + initial_prompt: Optional[str] Optional text to provide as a prompt for the first window. This can be used to provide, or "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns @@ -107,6 +129,14 @@ def transcribe( decode_options: dict Keyword arguments to construct `DecodingOptions` instances + clip_timestamps: Union[str, List[float]] + Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. + The last end timestamp defaults to the end of the file. 
+ + hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this threshold (in seconds) + when a possible hallucination is detected + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -119,6 +149,7 @@ def transcribe( # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) content_frames = mel.shape[-2] - N_FRAMES + content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) if verbose: system_encoding = sys.getdefaultencoding() @@ -155,6 +186,22 @@ def transcribe( task=task, ) + if isinstance(clip_timestamps, str): + clip_timestamps = [ + float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) + ] + seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] + if len(seek_points) == 0: + seek_points.append(0) + if len(seek_points) % 2 == 1: + seek_points.append(content_frames) + seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) + + punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" + + if word_timestamps and task == "translate": + warnings.warn("Word-level timestamps on translations may not be reliable.") + def decode_with_fallback(segment: mx.array) -> DecodingResult: temperatures = ( [temperature] if isinstance(temperature, (int, float)) else temperature @@ -195,7 +242,8 @@ def transcribe( return decode_result - seek = 0 + clip_idx = 0 + seek = seek_clips[clip_idx][0] input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2 time_precision = ( input_stride * HOP_LENGTH / SAMPLE_RATE @@ -232,10 +280,23 @@ def transcribe( total=content_frames, unit="frames", disable=verbose is not False ) as pbar: last_speech_timestamp = 0.0 - while seek < content_frames: + # NOTE: This loop is obscurely flattened to make the diff readable. + # A later commit should turn this into a simpler nested loop. 
+ # for seek_clip_start, seek_clip_end in seek_clips: + # while seek < seek_clip_end + while clip_idx < len(seek_clips): + seek_clip_start, seek_clip_end = seek_clips[clip_idx] + if seek < seek_clip_start: + seek = seek_clip_start + if seek >= seek_clip_end: + clip_idx += 1 + if clip_idx < len(seek_clips): + seek = seek_clips[clip_idx][0] + continue time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - mel_segment = mel[seek : seek + N_FRAMES] - segment_size = min(N_FRAMES, content_frames - seek) + window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) + segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) + mel_segment = mel[seek : seek + segment_size] segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype) @@ -260,6 +321,30 @@ def transcribe( previous_seek = seek current_segments = [] + # anomalous words are very long/short/improbable + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score + + def is_segment_anomaly(segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [w for w in segment["words"] if w["word"] not in punctuation] + words = words[:8] + score = sum(word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) + + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) + timestamp_tokens = tokens >= tokenizer.timestamp_begin single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] @@ -324,6 +409,83 @@ def transcribe( ) seek += segment_size + if word_timestamps: + add_word_timestamps( + segments=current_segments, + model=model, + tokenizer=tokenizer, + mel=mel_segment, + num_frames=segment_size, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + last_speech_timestamp=last_speech_timestamp, + ) + + if not single_timestamp_ending: + last_word_end = _get_end(current_segments) + if last_word_end is not None and last_word_end > time_offset: + seek = round(last_word_end * FRAMES_PER_SECOND) + + # skip silence before possible hallucinations + if hallucination_silence_threshold is not None: + threshold = hallucination_silence_threshold + if not single_timestamp_ending: + last_word_end = _get_end(current_segments) + if last_word_end is not None and last_word_end > time_offset: + remaining_duration = window_end_time - last_word_end + if remaining_duration > threshold: + seek = round(last_word_end * FRAMES_PER_SECOND) + else: + seek = previous_seek + segment_size + + # if first segment might be a hallucination, skip leading silence + first_segment = next_words_segment(current_segments) + if first_segment is not None and is_segment_anomaly(first_segment): + gap = first_segment["start"] - time_offset + if gap > threshold: + seek = previous_seek + round(gap * FRAMES_PER_SECOND) + continue + + # skip silence before any possible hallucination that is surrounded + # by silence or more hallucinations + hal_last_end = last_speech_timestamp + for si in range(len(current_segments)): + segment = current_segments[si] + if not segment["words"]: + continue + if is_segment_anomaly(segment): + next_segment = next_words_segment( + current_segments[si + 
1 :] + ) + if next_segment is not None: + hal_next_start = next_segment["words"][0]["start"] + else: + hal_next_start = time_offset + segment_duration + silence_before = ( + segment["start"] - hal_last_end > threshold + or segment["start"] < threshold + or segment["start"] - time_offset < 2.0 + ) + silence_after = ( + hal_next_start - segment["end"] > threshold + or is_segment_anomaly(next_segment) + or window_end_time - segment["end"] < 2.0 + ) + if silence_before and silence_after: + seek = round( + max(time_offset + 1, segment["start"]) + * FRAMES_PER_SECOND + ) + if content_duration - segment["end"] < threshold: + seek = content_frames + current_segments[si:] = [] + break + hal_last_end = segment["end"] + + last_word_end = _get_end(current_segments) + if last_word_end is not None: + last_speech_timestamp = last_word_end + if verbose: for segment in current_segments: start, end, text = segment["start"], segment["end"], segment["text"] diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index dfeb1e73..183eacc9 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -4,7 +4,7 @@ import base64 import gzip import math from dataclasses import dataclass -from typing import Dict, Iterable, Optional +from typing import Union import mlx.core as mx import mlx.nn as nn @@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module): else: k, v = kv_cache - wv = self.qkv_attention(q, k, v, mask) - return self.out(wv), (k, v) + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), (k, v), qk def qkv_attention(self, q, k, v, mask=None): n_batch, n_ctx, n_state = q.shape @@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module): qk = q @ k if mask is not None: qk = qk + mask[:n_ctx, :n_ctx] - qk = qk.astype(mx.float32) + w = mx.softmax(qk, axis=-1).astype(q.dtype) out = (w @ v).transpose(0, 2, 1, 3) out = out.reshape(n_batch, n_ctx, n_state) - return out + return out, qk class ResidualAttentionBlock(nn.Module): @@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module): def __call__(self, x, xa=None, mask=None, kv_cache=None): kv, cross_kv = kv_cache if kv_cache else (None, None) - y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv) + y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv) x += y + cross_qk = None if self.cross_attn: - y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) + y, cross_kv, cross_qk = self.cross_attn( + self.cross_attn_ln(x), xa, kv_cache=cross_kv + ) x += y x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype)) - return x, (kv, cross_kv) + return x, (kv, cross_kv), cross_qk class AudioEncoder(nn.Module): @@ -146,7 +149,7 @@ class AudioEncoder(nn.Module): x = x + self._positional_embedding for block in self.blocks: - x, _ = block(x) + x, _, _ = block(x) x = self.ln_post(x) return x @@ -191,11 +194,14 @@ class TextDecoder(nn.Module): if kv_cache is None: kv_cache = [None] * len(self.blocks) + cross_qk = [None] * len(self.blocks) for e, block in enumerate(self.blocks): - x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e]) + x, kv_cache[e], cross_qk[e] = block( + x, xa, mask=self._mask, kv_cache=kv_cache[e] + ) x = self.ln(x) - return x @ self.token_embedding.weight.T, kv_cache + return x @ self.token_embedding.weight.T, kv_cache, cross_qk class Whisper(nn.Module): @@ -218,6 +224,28 @@ class Whisper(nn.Module): self.dims.n_text_layer, dtype, ) + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see 
`set_alignment_heads()` below. + all_heads = np.zeros( + (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool + ) + all_heads[self.dims.n_text_layer // 2 :] = True + self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T) + + def set_alignment_heads(self, dump: Union[bytes, np.ndarray]): + if isinstance(dump, np.ndarray): + self.alignment_heads = mx.array(dump) + elif isinstance(dump, bytes): + array = np.frombuffer( + gzip.decompress(base64.b85decode(dump)), dtype=bool + ).copy() + mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head) + self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T) + else: + raise ValueError( + f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing" + " alignment_head information" + ) def embed_audio(self, mel): return self.encoder(mel) @@ -225,6 +253,10 @@ class Whisper(nn.Module): def logits(self, tokens, audio_features): return self.decoder(tokens, audio_features)[0] + def forward_with_cross_qk(self, mel, tokens): + logits, _, cross_qk = self.decoder(tokens, self.encoder(mel)) + return logits, cross_qk + def __call__(self, mel, tokens): return self.decoder(tokens, self.encoder(mel))[0]
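
For reference, `set_alignment_heads` above accepts either an `(N, 2)` array of `(layer, head)` indices or a gzip-compressed, base85-encoded boolean mask (the format used by openai-whisper and by `convert.py` in this patch). A minimal standalone sketch of the same decode path; the `decode_alignment_heads` helper name and any concrete dump value are assumptions for illustration, not part of this patch:

```python
import base64
import gzip

import numpy as np


def decode_alignment_heads(dump: bytes, n_text_layer: int, n_text_head: int) -> np.ndarray:
    """Mirror the bytes branch of Whisper.set_alignment_heads: dump -> (N, 2) (layer, head) pairs."""
    mask = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
    mask = mask.reshape(n_text_layer, n_text_head)
    # Rows are (layer_index, head_index); these select the per-layer cross-attention
    # maps that find_alignment() reads out of Whisper.forward_with_cross_qk().
    return np.asarray(mask.nonzero()).T
```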