[Whisper] Add word timestamps and confidence scores (#201)

* Add word timestamps and confidence scores

* Create a separate forward_with_cross_qk function

* Move multiple ops from np to mlx, clean comments

* Save alignment_heads

* Cast qk to fp32

* Add test for word-level timestamps and confidence scores

* format + readme

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by bofeng huang on 2024-01-07 19:01:29 +01:00, committed by GitHub
parent 25ebd36112
commit bf9926489e
7 changed files with 398 additions and 111 deletions

View File

@@ -2,7 +2,7 @@

Speech recognition with Whisper in MLX. Whisper is a set of open source speech
recognition models from OpenAI, ranging from 39 million to 1.5 billion
-parameters[^1].
parameters.[^1]

### Setup

@@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```

-Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
Next, download the Whisper PyTorch checkpoint and convert the weights to the
MLX format. For example, to convert the `tiny` model use:

```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny

@@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there.

Transcribe audio with:

-```
```python
import whisper

text = whisper.transcribe(speech_file)["text"]
```
The `transcribe` function also supports word-level timestamps. You can generate
these with:
```python
output = whisper.transcribe(speech_file, word_timestamps=True)
print(output["segments"][0]["words"])
```
To see more transcription options use:
```
>>> help(whisper.transcribe)
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
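For reference, each element of the `words` list shown above is a dict with the word text, its start and end times in seconds, and a confidence score. A sketch of the output shape, with values borrowed from the test expectations added later in this commit:

```python
[
    {"word": " Then", "start": 0.0, "end": 0.94, "probability": 0.855},
    {"word": " the", "start": 0.94, "end": 1.12, "probability": 0.650},
    {"word": " good", "start": 1.12, "end": 1.32, "probability": 0.550},
    # ... one entry per word in the segment
]
```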

View File

@@ -199,6 +199,10 @@ def torch_to_mlx(
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
return mlx_model
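The conversion now carries over the checkpoint's cross-attention alignment heads, which the timing code uses for word alignment. A hedged sketch of what `set_alignment_heads` accepts when called by hand; the (layer, head) pairs below are made up for illustration, while `torch_to_mlx` derives the real ones from the PyTorch checkpoint's sparse `alignment_heads` tensor:

```python
import numpy as np

# Hypothetical (layer, head) index pairs, shape (num_alignment_heads, 2).
heads = np.array([[2, 0], [3, 5]])

# mlx_model is the Whisper instance built by torch_to_mlx above.
mlx_model.set_alignment_heads(heads)
```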

View File

@@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
check_segment(result["segments"][5], expected_5)
check_segment(result["segments"][73], expected_73)
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
TEST_AUDIO,
model_path=MLX_FP16_MODEL_PATH,
word_timestamps=True,
)
# result predicted with openai-whisper
expected_0 = [
{
"word": " Then",
"start": 0.0,
"end": 0.94,
"probability": 0.855542778968811,
},
{
"word": " the",
"start": 0.94,
"end": 1.12,
"probability": 0.6500106453895569,
},
{
"word": " good",
"start": 1.12,
"end": 1.32,
"probability": 0.5503873825073242,
},
{
"word": " soul",
"start": 1.32,
"end": 1.56,
"probability": 0.46757155656814575,
},
{
"word": " openly",
"start": 1.56,
"end": 2.0,
"probability": 0.9840946793556213,
},
{
"word": " sorted",
"start": 2.0,
"end": 2.38,
"probability": 0.24167272448539734,
},
{
"word": " the",
"start": 2.38,
"end": 2.58,
"probability": 0.9875414967536926,
},
{
"word": " boat",
"start": 2.58,
"end": 2.8,
"probability": 0.5856029391288757,
},
{
"word": " and",
"start": 2.8,
"end": 2.98,
"probability": 0.913351833820343,
},
{
"word": " she",
"start": 2.98,
"end": 3.1,
"probability": 0.9913808703422546,
},
{
"word": " had",
"start": 3.1,
"end": 3.32,
"probability": 0.9952940344810486,
},
{
"word": " buoyed",
"start": 3.32,
"end": 3.58,
"probability": 0.6411589980125427,
},
{
"word": " so",
"start": 3.58,
"end": 3.8,
"probability": 0.9682658314704895,
},
{
"word": " long",
"start": 3.8,
"end": 4.06,
"probability": 0.9953522682189941,
},
{
"word": " in",
"start": 4.06,
"end": 4.26,
"probability": 0.6745936870574951,
},
{
"word": " secret",
"start": 4.26,
"end": 4.56,
"probability": 0.9905064702033997,
},
{
"word": " and",
"start": 4.56,
"end": 4.9,
"probability": 0.856008768081665,
},
{
"word": " bravely",
"start": 4.9,
"end": 5.28,
"probability": 0.8477402329444885,
},
]
def check_words(words, expected_words):
for word, expected_word in zip(words, expected_words):
for k, v in expected_word.items():
if isinstance(v, float):
self.assertAlmostEqual(word[k], v, places=1)
else:
self.assertEqual(word[k], v)
# Randomly check a couple of segments
check_words(result["segments"][0]["words"], expected_0)
class TestAudio(unittest.TestCase):
def test_load(self):

View File

@@ -141,7 +141,7 @@ class Inference:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]

-logits, self.kv_cache = self.model.decoder(
logits, self.kv_cache, _ = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits.astype(mx.float32)

View File

@@ -1,15 +1,13 @@
# Copyright © 2023 Apple Inc.

import itertools
-import subprocess
-import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import mlx.core as mx
import numba
import numpy as np
-import torch
-import torch.nn.functional as F
from scipy import signal

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
@@ -18,7 +16,7 @@ if TYPE_CHECKING:
from .model import Whisper

-def median_filter(x: torch.Tensor, filter_width: int):
def median_filter(x: np.ndarray, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
@@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int):
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"

-result = None
-x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
-if x.is_cuda:
-try:
-from .triton_ops import median_filter_cuda
-result = median_filter_cuda(x, filter_width)
-except (RuntimeError, subprocess.CalledProcessError):
-warnings.warn(
-"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
-"falling back to a slower median kernel implementation..."
-)
-if result is None:
-# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
-result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
# todo: more efficient version in mlx
result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[
..., pad_width:-pad_width
]

if ndim <= 2:
result = result[0, 0]
@@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray):
return backtrace(trace)

-def dtw_cuda(x, BLOCK_SIZE=1024):
-from .triton_ops import dtw_kernel
-M, N = x.shape
-assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
-x_skew = (
-F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
-)
-x_skew = x_skew.T.contiguous()
-cost = torch.ones(N + M + 2, M + 2) * np.inf
-cost[0, 0] = 0
-cost = cost.cuda()
-trace = torch.zeros_like(cost, dtype=torch.int32)
-dtw_kernel[(1,)](
-cost,
-trace,
-x_skew,
-x_skew.stride(0),
-cost.stride(0),
-trace.stride(0),
-N,
-M,
-BLOCK_SIZE=BLOCK_SIZE,
-)
-trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
-:, : N + 1
-]
-return backtrace(trace.cpu().numpy())
-def dtw(x: torch.Tensor) -> np.ndarray:
-if x.is_cuda:
-try:
-return dtw_cuda(x)
-except (RuntimeError, subprocess.CalledProcessError):
-warnings.warn(
-"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
-"falling back to a slower DTW implementation..."
-)
-return dtw_cpu(x.double().cpu().numpy())
def dtw(x: np.ndarray) -> np.ndarray:
# todo: more efficient version in mlx
return dtw_cpu(x)

@dataclass
@@ -166,7 +113,7 @@ def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
-mel: torch.Tensor,
mel: mx.array,
num_frames: int,
*,
medfilt_width: int = 7,
@@ -175,41 +122,36 @@ def find_alignment(
if len(text_tokens) == 0:
return []

-tokens = torch.tensor(
tokens = mx.array(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
-).to(model.device)
)

-# install hooks on the cross attention layers to retrieve the attention weights
-QKs = [None] * model.dims.n_text_layer
-hooks = [
-block.cross_attn.register_forward_hook(
-lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
-)
-for i, block in enumerate(model.decoder.blocks)
-]
-with torch.no_grad():
-logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
-token_probs = sampled_logits.softmax(dim=-1)
-text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
-text_token_probs = text_token_probs.tolist()
-for hook in hooks:
-hook.remove()
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
sampled_logits.dtype
)
text_token_probs = mx.take_along_axis(
token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
text_token_probs = np.array(text_token_probs)

# heads * tokens * frames
-weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = mx.stack(
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
)
weights = weights[:, :, : num_frames // 2]
-weights = (weights * qk_scale).softmax(dim=-1)
weights = mx.softmax(weights * qk_scale, axis=-1)
-std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std
-weights = median_filter(weights, medfilt_width)
weights = median_filter(np.array(weights), medfilt_width)

matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
@@ -281,7 +223,7 @@ def add_word_timestamps(
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
-mel: torch.Tensor,
mel: mx.array,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",

@@ -301,6 +243,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2

# hack: truncate long words at sentence boundaries.
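The torch/Triton median filter is replaced by `scipy.signal.medfilt` over the reflect-padded attention weights. A small self-contained check, under the assumption that the input is a (heads, tokens, frames) array as in `find_alignment`, that the scipy call matches a naive sliding-window median:

```python
import numpy as np
from scipy import signal

filter_width = 7
pad_width = filter_width // 2
x = np.random.rand(2, 5, 20).astype(np.float32)  # heads x tokens x frames

# Reflect-pad along the frame axis, as the new median_filter does.
padded = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")

# scipy path used in the new median_filter, edges sliced back off.
fast = signal.medfilt(padded, kernel_size=(1, 1, filter_width))[..., pad_width:-pad_width]

# Reference: median over every length-7 window along the last axis.
windows = np.lib.stride_tricks.sliding_window_view(padded, filter_width, axis=-1)
naive = np.median(windows, axis=-1)

assert np.allclose(fast, naive)
```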

View File

@@ -1,7 +1,8 @@
# Copyright © 2023 Apple Inc.

import sys
import warnings
-from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import numpy as np
@@ -18,7 +19,8 @@ from .audio import (
)
from .decoding import DecodingOptions, DecodingResult
from .load_models import load_model
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, get_tokenizer

def _format_timestamp(seconds: float):
@@ -38,6 +40,13 @@ def _format_timestamp(seconds: float):
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def _get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ModelHolder:
model = None
model_path = None
@@ -61,8 +70,11 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
@@ -99,6 +111,16 @@ def transcribe(
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns

@@ -107,6 +129,14 @@ def transcribe(
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -119,6 +149,7 @@
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-2] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if verbose:
system_encoding = sys.getdefaultencoding()
@@ -155,6 +186,22 @@ def transcribe(
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: mx.array) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
@@ -195,7 +242,8 @@ def transcribe(
return decode_result

-seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
@@ -232,10 +280,23 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
-while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
#     while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
-mel_segment = mel[seek : seek + N_FRAMES]
-segment_size = min(N_FRAMES, content_frames - seek)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
@@ -260,6 +321,30 @@ def transcribe(
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens = tokens >= tokenizer.timestamp_begin
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -324,6 +409,83 @@ def transcribe(
)
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = _get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
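Taken together, the new `transcribe` options can be exercised as below. This is a hedged usage sketch: `speech.mp3` is a placeholder path and the parameter values are illustrative rather than recommended defaults.

```python
import whisper

result = whisper.transcribe(
    "speech.mp3",                         # placeholder audio file
    word_timestamps=True,                 # enable DTW-based word alignment
    clip_timestamps="0,30",               # only decode the first 30 seconds
    hallucination_silence_threshold=2.0,  # skip long silences around suspected hallucinations
)

for segment in result["segments"]:
    for word in segment["words"]:
        print(f'{word["start"]:6.2f} {word["end"]:6.2f} {word["probability"]:.2f} {word["word"]}')
```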

View File

@@ -4,7 +4,7 @@ import base64
import gzip
import math
from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
from typing import Union

import mlx.core as mx
import mlx.nn as nn
@@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module):
else:
k, v = kv_cache

-wv = self.qkv_attention(q, k, v, mask)
-return self.out(wv), (k, v)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), (k, v), qk

def qkv_attention(self, q, k, v, mask=None):
n_batch, n_ctx, n_state = q.shape
@@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.astype(mx.float32)

w = mx.softmax(qk, axis=-1).astype(q.dtype)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
-return out
return out, qk

class ResidualAttentionBlock(nn.Module):
@@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module):
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
-y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
x += y
cross_qk = None
if self.cross_attn:
-y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
y, cross_kv, cross_qk = self.cross_attn(
self.cross_attn_ln(x), xa, kv_cache=cross_kv
)
x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
-return x, (kv, cross_kv)
return x, (kv, cross_kv), cross_qk

class AudioEncoder(nn.Module):
@@ -146,7 +149,7 @@ class AudioEncoder(nn.Module):
x = x + self._positional_embedding

for block in self.blocks:
-x, _ = block(x)
x, _, _ = block(x)

x = self.ln_post(x)
return x
@@ -191,11 +194,14 @@ class TextDecoder(nn.Module):
if kv_cache is None:
kv_cache = [None] * len(self.blocks)

cross_qk = [None] * len(self.blocks)
for e, block in enumerate(self.blocks):
-x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])
x, kv_cache[e], cross_qk[e] = block(
x, xa, mask=self._mask, kv_cache=kv_cache[e]
)

x = self.ln(x)
-return x @ self.token_embedding.weight.T, kv_cache
return x @ self.token_embedding.weight.T, kv_cache, cross_qk

class Whisper(nn.Module):
@@ -218,6 +224,28 @@ class Whisper(nn.Module):
self.dims.n_text_layer,
dtype,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = np.zeros(
(self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
if isinstance(dump, np.ndarray):
self.alignment_heads = mx.array(dump)
elif isinstance(dump, bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
else:
raise ValueError(
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
" alignment_head information"
)
def embed_audio(self, mel):
return self.encoder(mel)
@@ -225,6 +253,10 @@
def logits(self, tokens, audio_features):
return self.decoder(tokens, audio_features)[0]
def forward_with_cross_qk(self, mel, tokens):
logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
return logits, cross_qk
def __call__(self, mel, tokens):
return self.decoder(tokens, self.encoder(mel))[0]
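For completeness, a hedged sketch of how the new `forward_with_cross_qk` output is consumed; `model`, `mel` (a log-Mel spectrogram), and `tokens` are assumed to already exist, mirroring what `find_alignment` does in the timing module.

```python
import mlx.core as mx

# One batched forward pass that also returns the pre-softmax cross-attention
# scores: cross_qk is a list with one entry per decoder layer, each of shape
# (batch, n_text_head, n_tokens, n_audio_ctx).
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])

# Gather only the heads designated for time alignment.
weights = mx.stack(
    [cross_qk[l.item()][0, h.item()] for l, h in model.alignment_heads]
)
print(weights.shape)  # (num_alignment_heads, n_tokens, n_audio_ctx)
```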