mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
[Whisper] Add word timestamps and confidence scores (#201)
* Add word timestamps and confidence scores * Create a separate forward_with_cross_qk function * Move multiple ops from np to mlx, clean comments * Save alignment_heads * Cast qk to fp32 * Add test for word-level timestamps and confidence scores * format + readme * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
25ebd36112
commit
bf9926489e
@ -2,7 +2,7 @@
|
||||
|
||||
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
|
||||
recognition models from OpenAI, ranging from 39 million to 1.5 billion
|
||||
parameters[^1].
|
||||
parameters.[^1]
|
||||
|
||||
### Setup
|
||||
|
||||
@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/):
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
|
||||
Next, download the Whisper PyTorch checkpoint and convert the weights to the
|
||||
MLX format. For example, to convert the `tiny` model use:
|
||||
|
||||
```
|
||||
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
|
||||
@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there.
|
||||
|
||||
Transcribe audio with:
|
||||
|
||||
```
|
||||
```python
|
||||
import whisper
|
||||
|
||||
text = whisper.transcribe(speech_file)["text"]
|
||||
```
|
||||
|
||||
The `transcribe` function also supports word-level timestamps. You can generate
|
||||
these with:
|
||||
|
||||
```python
|
||||
output = whisper.transcribe(speech_file, word_timestamps=True)
|
||||
print(output["segments"][0]["words"])
|
||||
```
|
||||
|
||||
To see more transcription options use:
|
||||
|
||||
```
|
||||
>>> help(whisper.transcribe)
|
||||
```
|
||||
|
||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
|
||||
|
@ -199,6 +199,10 @@ def torch_to_mlx(
|
||||
mlx_model = Whisper(torch_model.dims, dtype)
|
||||
params = tree_map(lambda p: p.astype(dtype), params)
|
||||
mlx_model.update(params)
|
||||
|
||||
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
|
||||
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
|
||||
|
||||
return mlx_model
|
||||
|
||||
|
||||
|
131
whisper/test.py
131
whisper/test.py
@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
|
||||
check_segment(result["segments"][5], expected_5)
|
||||
check_segment(result["segments"][73], expected_73)
|
||||
|
||||
def test_transcribe_word_level_timestamps_confidence_scores(self):
|
||||
result = whisper.transcribe(
|
||||
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
|
||||
TEST_AUDIO,
|
||||
model_path=MLX_FP16_MODEL_PATH,
|
||||
word_timestamps=True,
|
||||
)
|
||||
|
||||
# result predicted with openai-whisper
|
||||
expected_0 = [
|
||||
{
|
||||
"word": " Then",
|
||||
"start": 0.0,
|
||||
"end": 0.94,
|
||||
"probability": 0.855542778968811,
|
||||
},
|
||||
{
|
||||
"word": " the",
|
||||
"start": 0.94,
|
||||
"end": 1.12,
|
||||
"probability": 0.6500106453895569,
|
||||
},
|
||||
{
|
||||
"word": " good",
|
||||
"start": 1.12,
|
||||
"end": 1.32,
|
||||
"probability": 0.5503873825073242,
|
||||
},
|
||||
{
|
||||
"word": " soul",
|
||||
"start": 1.32,
|
||||
"end": 1.56,
|
||||
"probability": 0.46757155656814575,
|
||||
},
|
||||
{
|
||||
"word": " openly",
|
||||
"start": 1.56,
|
||||
"end": 2.0,
|
||||
"probability": 0.9840946793556213,
|
||||
},
|
||||
{
|
||||
"word": " sorted",
|
||||
"start": 2.0,
|
||||
"end": 2.38,
|
||||
"probability": 0.24167272448539734,
|
||||
},
|
||||
{
|
||||
"word": " the",
|
||||
"start": 2.38,
|
||||
"end": 2.58,
|
||||
"probability": 0.9875414967536926,
|
||||
},
|
||||
{
|
||||
"word": " boat",
|
||||
"start": 2.58,
|
||||
"end": 2.8,
|
||||
"probability": 0.5856029391288757,
|
||||
},
|
||||
{
|
||||
"word": " and",
|
||||
"start": 2.8,
|
||||
"end": 2.98,
|
||||
"probability": 0.913351833820343,
|
||||
},
|
||||
{
|
||||
"word": " she",
|
||||
"start": 2.98,
|
||||
"end": 3.1,
|
||||
"probability": 0.9913808703422546,
|
||||
},
|
||||
{
|
||||
"word": " had",
|
||||
"start": 3.1,
|
||||
"end": 3.32,
|
||||
"probability": 0.9952940344810486,
|
||||
},
|
||||
{
|
||||
"word": " buoyed",
|
||||
"start": 3.32,
|
||||
"end": 3.58,
|
||||
"probability": 0.6411589980125427,
|
||||
},
|
||||
{
|
||||
"word": " so",
|
||||
"start": 3.58,
|
||||
"end": 3.8,
|
||||
"probability": 0.9682658314704895,
|
||||
},
|
||||
{
|
||||
"word": " long",
|
||||
"start": 3.8,
|
||||
"end": 4.06,
|
||||
"probability": 0.9953522682189941,
|
||||
},
|
||||
{
|
||||
"word": " in",
|
||||
"start": 4.06,
|
||||
"end": 4.26,
|
||||
"probability": 0.6745936870574951,
|
||||
},
|
||||
{
|
||||
"word": " secret",
|
||||
"start": 4.26,
|
||||
"end": 4.56,
|
||||
"probability": 0.9905064702033997,
|
||||
},
|
||||
{
|
||||
"word": " and",
|
||||
"start": 4.56,
|
||||
"end": 4.9,
|
||||
"probability": 0.856008768081665,
|
||||
},
|
||||
{
|
||||
"word": " bravely",
|
||||
"start": 4.9,
|
||||
"end": 5.28,
|
||||
"probability": 0.8477402329444885,
|
||||
},
|
||||
]
|
||||
|
||||
def check_words(words, expected_words):
|
||||
for word, expected_word in zip(words, expected_words):
|
||||
for k, v in expected_word.items():
|
||||
if isinstance(v, float):
|
||||
self.assertAlmostEqual(word[k], v, places=1)
|
||||
else:
|
||||
self.assertEqual(word[k], v)
|
||||
|
||||
# Randomly check a couple of segments
|
||||
check_words(result["segments"][0]["words"], expected_0)
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
def test_load(self):
|
||||
|
@ -141,7 +141,7 @@ class Inference:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
logits, self.kv_cache = self.model.decoder(
|
||||
logits, self.kv_cache, _ = self.model.decoder(
|
||||
tokens, audio_features, kv_cache=self.kv_cache
|
||||
)
|
||||
return logits.astype(mx.float32)
|
||||
|
@ -1,15 +1,13 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import itertools
|
||||
import subprocess
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import mlx.core as mx
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy import signal
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .tokenizer import Tokenizer
|
||||
@ -18,7 +16,7 @@ if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def median_filter(x: torch.Tensor, filter_width: int):
|
||||
def median_filter(x: np.ndarray, filter_width: int):
|
||||
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
|
||||
pad_width = filter_width // 2
|
||||
if x.shape[-1] <= pad_width:
|
||||
@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int):
|
||||
filter_width > 0 and filter_width % 2 == 1
|
||||
), "`filter_width` should be an odd number"
|
||||
|
||||
result = None
|
||||
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
||||
if x.is_cuda:
|
||||
try:
|
||||
from .triton_ops import median_filter_cuda
|
||||
x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
|
||||
|
||||
result = median_filter_cuda(x, filter_width)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower median kernel implementation..."
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
# todo: more efficient version in mlx
|
||||
result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[
|
||||
..., pad_width:-pad_width
|
||||
]
|
||||
|
||||
if ndim <= 2:
|
||||
result = result[0, 0]
|
||||
@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray):
|
||||
return backtrace(trace)
|
||||
|
||||
|
||||
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
from .triton_ops import dtw_kernel
|
||||
|
||||
M, N = x.shape
|
||||
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||
|
||||
x_skew = (
|
||||
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||
)
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
cost,
|
||||
trace,
|
||||
x_skew,
|
||||
x_skew.stride(0),
|
||||
cost.stride(0),
|
||||
trace.stride(0),
|
||||
N,
|
||||
M,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||
:, : N + 1
|
||||
]
|
||||
return backtrace(trace.cpu().numpy())
|
||||
|
||||
|
||||
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||
if x.is_cuda:
|
||||
try:
|
||||
return dtw_cuda(x)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower DTW implementation..."
|
||||
)
|
||||
|
||||
return dtw_cpu(x.double().cpu().numpy())
|
||||
def dtw(x: np.ndarray) -> np.ndarray:
|
||||
# todo: more efficient version in mlx
|
||||
return dtw_cpu(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -166,7 +113,7 @@ def find_alignment(
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
text_tokens: List[int],
|
||||
mel: torch.Tensor,
|
||||
mel: mx.array,
|
||||
num_frames: int,
|
||||
*,
|
||||
medfilt_width: int = 7,
|
||||
@ -175,41 +122,36 @@ def find_alignment(
|
||||
if len(text_tokens) == 0:
|
||||
return []
|
||||
|
||||
tokens = torch.tensor(
|
||||
tokens = mx.array(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*text_tokens,
|
||||
tokenizer.eot,
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# install hooks on the cross attention layers to retrieve the attention weights
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
||||
# consider only the logits associated with predicting text
|
||||
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
||||
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
|
||||
sampled_logits.dtype
|
||||
)
|
||||
text_token_probs = mx.take_along_axis(
|
||||
token_probs, mx.array(text_tokens)[:, None], axis=1
|
||||
).squeeze(1)
|
||||
text_token_probs = np.array(text_token_probs)
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||
weights = mx.stack(
|
||||
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
|
||||
)
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = mx.softmax(weights * qk_scale, axis=-1)
|
||||
mean = mx.mean(weights, axis=-2, keepdims=True)
|
||||
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
||||
weights = (weights - mean) / std
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
weights = median_filter(np.array(weights), medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
@ -281,7 +223,7 @@ def add_word_timestamps(
|
||||
segments: List[dict],
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
mel: torch.Tensor,
|
||||
mel: mx.array,
|
||||
num_frames: int,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
@ -301,6 +243,7 @@ def add_word_timestamps(
|
||||
word_durations = np.array([t.end - t.start for t in alignment])
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||
median_duration = min(0.7, float(median_duration))
|
||||
max_duration = median_duration * 2
|
||||
|
||||
# hack: truncate long words at sentence boundaries.
|
||||
|
@ -1,7 +1,8 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import sys
|
||||
from typing import Optional, Tuple, Union
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@ -18,7 +19,8 @@ from .audio import (
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .load_models import load_model
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, get_tokenizer
|
||||
|
||||
|
||||
def _format_timestamp(seconds: float):
|
||||
@ -38,6 +40,13 @@ def _format_timestamp(seconds: float):
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
||||
|
||||
|
||||
def _get_end(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||
segments[-1]["end"] if segments else None,
|
||||
)
|
||||
|
||||
|
||||
class ModelHolder:
|
||||
model = None
|
||||
model_path = None
|
||||
@ -61,8 +70,11 @@ def transcribe(
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@ -99,6 +111,16 @@ def transcribe(
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
word_timestamps: bool
|
||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||
and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||
|
||||
append_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
@ -107,6 +129,14 @@ def transcribe(
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||
The last end timestamp defaults to the end of the file.
|
||||
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||
when a possible hallucination is detected
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
@ -119,6 +149,7 @@ def transcribe(
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-2] - N_FRAMES
|
||||
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||
|
||||
if verbose:
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
@ -155,6 +186,22 @@ def transcribe(
|
||||
task=task,
|
||||
)
|
||||
|
||||
if isinstance(clip_timestamps, str):
|
||||
clip_timestamps = [
|
||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||
]
|
||||
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||
if len(seek_points) == 0:
|
||||
seek_points.append(0)
|
||||
if len(seek_points) % 2 == 1:
|
||||
seek_points.append(content_frames)
|
||||
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||
|
||||
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: mx.array) -> DecodingResult:
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
@ -195,7 +242,8 @@ def transcribe(
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
clip_idx = 0
|
||||
seek = seek_clips[clip_idx][0]
|
||||
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
@ -232,10 +280,23 @@ def transcribe(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames:
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
# A later commit should turn this into a simpler nested loop.
|
||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||
# while seek < seek_clip_end
|
||||
while clip_idx < len(seek_clips):
|
||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||
if seek < seek_clip_start:
|
||||
seek = seek_clip_start
|
||||
if seek >= seek_clip_end:
|
||||
clip_idx += 1
|
||||
if clip_idx < len(seek_clips):
|
||||
seek = seek_clips[clip_idx][0]
|
||||
continue
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||
mel_segment = mel[seek : seek + segment_size]
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
|
||||
|
||||
@ -260,6 +321,30 @@ def transcribe(
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
# anomalous words are very long/short/improbable
|
||||
def word_anomaly_score(word: dict) -> float:
|
||||
probability = word.get("probability", 0.0)
|
||||
duration = word["end"] - word["start"]
|
||||
score = 0.0
|
||||
if probability < 0.15:
|
||||
score += 1.0
|
||||
if duration < 0.133:
|
||||
score += (0.133 - duration) * 15
|
||||
if duration > 2.0:
|
||||
score += duration - 2.0
|
||||
return score
|
||||
|
||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||
if segment is None or not segment["words"]:
|
||||
return False
|
||||
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||
words = words[:8]
|
||||
score = sum(word_anomaly_score(w) for w in words)
|
||||
return score >= 3 or score + 0.01 >= len(words)
|
||||
|
||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||
return next((s for s in segments if s["words"]), None)
|
||||
|
||||
timestamp_tokens = tokens >= tokenizer.timestamp_begin
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
@ -324,6 +409,83 @@ def transcribe(
|
||||
)
|
||||
seek += segment_size
|
||||
|
||||
if word_timestamps:
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
mel=mel_segment,
|
||||
num_frames=segment_size,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = _get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
|
||||
# skip silence before possible hallucinations
|
||||
if hallucination_silence_threshold is not None:
|
||||
threshold = hallucination_silence_threshold
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = _get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
remaining_duration = window_end_time - last_word_end
|
||||
if remaining_duration > threshold:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
else:
|
||||
seek = previous_seek + segment_size
|
||||
|
||||
# if first segment might be a hallucination, skip leading silence
|
||||
first_segment = next_words_segment(current_segments)
|
||||
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||
gap = first_segment["start"] - time_offset
|
||||
if gap > threshold:
|
||||
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||
continue
|
||||
|
||||
# skip silence before any possible hallucination that is surrounded
|
||||
# by silence or more hallucinations
|
||||
hal_last_end = last_speech_timestamp
|
||||
for si in range(len(current_segments)):
|
||||
segment = current_segments[si]
|
||||
if not segment["words"]:
|
||||
continue
|
||||
if is_segment_anomaly(segment):
|
||||
next_segment = next_words_segment(
|
||||
current_segments[si + 1 :]
|
||||
)
|
||||
if next_segment is not None:
|
||||
hal_next_start = next_segment["words"][0]["start"]
|
||||
else:
|
||||
hal_next_start = time_offset + segment_duration
|
||||
silence_before = (
|
||||
segment["start"] - hal_last_end > threshold
|
||||
or segment["start"] < threshold
|
||||
or segment["start"] - time_offset < 2.0
|
||||
)
|
||||
silence_after = (
|
||||
hal_next_start - segment["end"] > threshold
|
||||
or is_segment_anomaly(next_segment)
|
||||
or window_end_time - segment["end"] < 2.0
|
||||
)
|
||||
if silence_before and silence_after:
|
||||
seek = round(
|
||||
max(time_offset + 1, segment["start"])
|
||||
* FRAMES_PER_SECOND
|
||||
)
|
||||
if content_duration - segment["end"] < threshold:
|
||||
seek = content_frames
|
||||
current_segments[si:] = []
|
||||
break
|
||||
hal_last_end = segment["end"]
|
||||
|
||||
last_word_end = _get_end(current_segments)
|
||||
if last_word_end is not None:
|
||||
last_speech_timestamp = last_word_end
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
|
@ -4,7 +4,7 @@ import base64
|
||||
import gzip
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional
|
||||
from typing import Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module):
|
||||
else:
|
||||
k, v = kv_cache
|
||||
|
||||
wv = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), (k, v)
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), (k, v), qk
|
||||
|
||||
def qkv_attention(self, q, k, v, mask=None):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module):
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
|
||||
qk = qk.astype(mx.float32)
|
||||
|
||||
w = mx.softmax(qk, axis=-1).astype(q.dtype)
|
||||
out = (w @ v).transpose(0, 2, 1, 3)
|
||||
out = out.reshape(n_batch, n_ctx, n_state)
|
||||
return out
|
||||
return out, qk
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, kv_cache=None):
|
||||
kv, cross_kv = kv_cache if kv_cache else (None, None)
|
||||
y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
|
||||
y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
|
||||
x += y
|
||||
cross_qk = None
|
||||
if self.cross_attn:
|
||||
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
||||
y, cross_kv, cross_qk = self.cross_attn(
|
||||
self.cross_attn_ln(x), xa, kv_cache=cross_kv
|
||||
)
|
||||
x += y
|
||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
|
||||
return x, (kv, cross_kv)
|
||||
return x, (kv, cross_kv), cross_qk
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
@ -146,7 +149,7 @@ class AudioEncoder(nn.Module):
|
||||
x = x + self._positional_embedding
|
||||
|
||||
for block in self.blocks:
|
||||
x, _ = block(x)
|
||||
x, _, _ = block(x)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
@ -191,11 +194,14 @@ class TextDecoder(nn.Module):
|
||||
|
||||
if kv_cache is None:
|
||||
kv_cache = [None] * len(self.blocks)
|
||||
cross_qk = [None] * len(self.blocks)
|
||||
for e, block in enumerate(self.blocks):
|
||||
x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])
|
||||
x, kv_cache[e], cross_qk[e] = block(
|
||||
x, xa, mask=self._mask, kv_cache=kv_cache[e]
|
||||
)
|
||||
|
||||
x = self.ln(x)
|
||||
return x @ self.token_embedding.weight.T, kv_cache
|
||||
return x @ self.token_embedding.weight.T, kv_cache, cross_qk
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
@ -218,6 +224,28 @@ class Whisper(nn.Module):
|
||||
self.dims.n_text_layer,
|
||||
dtype,
|
||||
)
|
||||
# use the last half among the decoder layers for time alignment by default;
|
||||
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||
all_heads = np.zeros(
|
||||
(self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
|
||||
)
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
|
||||
|
||||
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
|
||||
if isinstance(dump, np.ndarray):
|
||||
self.alignment_heads = mx.array(dump)
|
||||
elif isinstance(dump, bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
|
||||
self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
|
||||
" alignment_head information"
|
||||
)
|
||||
|
||||
def embed_audio(self, mel):
|
||||
return self.encoder(mel)
|
||||
@ -225,6 +253,10 @@ class Whisper(nn.Module):
|
||||
def logits(self, tokens, audio_features):
|
||||
return self.decoder(tokens, audio_features)[0]
|
||||
|
||||
def forward_with_cross_qk(self, mel, tokens):
|
||||
logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
|
||||
return logits, cross_qk
|
||||
|
||||
def __call__(self, mel, tokens):
|
||||
return self.decoder(tokens, self.encoder(mel))[0]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user