commit bf9926489e (parent 25ebd36112)
[Whisper] Add word timestamps and confidence scores (#201)

* Add word timestamps and confidence scores
* Create a separate forward_with_cross_qk function
* Move multiple ops from np to mlx, clean comments
* Save alignment_heads
* Cast qk to fp32
* Add test for word-level timestamps and confidence scores
* format + readme
* nit

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -2,7 +2,7 @@

 Speech recognition with Whisper in MLX. Whisper is a set of open source speech
 recognition models from OpenAI, ranging from 39 million to 1.5 billion
-parameters[^1].
+parameters.[^1]

 ### Setup


@@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/):
 brew install ffmpeg
 ```

-Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
+Next, download the Whisper PyTorch checkpoint and convert the weights to the
+MLX format. For example, to convert the `tiny` model use:

 ```
 python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny

@@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there.

 Transcribe audio with:

-```
+```python
 import whisper

 text = whisper.transcribe(speech_file)["text"]
 ```

+The `transcribe` function also supports word-level timestamps. You can generate
+these with:
+
+```python
+output = whisper.transcribe(speech_file, word_timestamps=True)
+print(output["segments"][0]["words"])
+```
+
+To see more transcription options use:
+
+```
+>>> help(whisper.transcribe)
+```
+
 [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
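For reference, each entry in the `words` list produced by the new `word_timestamps=True` path is a small dictionary holding the word text, its start and end times in seconds, and a confidence score; the test added later in this commit checks entries of exactly this shape, for example:

```python
{"word": " Then", "start": 0.0, "end": 0.94, "probability": 0.855542778968811}
```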
@@ -199,6 +199,10 @@ def torch_to_mlx(
     mlx_model = Whisper(torch_model.dims, dtype)
     params = tree_map(lambda p: p.astype(dtype), params)
     mlx_model.update(params)

+    if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
+        mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
+
     return mlx_model

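As a hedged illustration of what the new conversion step saves: OpenAI checkpoints keep `alignment_heads` as a sparse boolean (layer, head) mask, and `.indices().T` turns it into the index pairs that the new `set_alignment_heads` method (added in `whisper.py` below) stores. A minimal sketch with invented values:

```python
import torch

# hypothetical mask for a 4-layer, 6-head text decoder; real checkpoints ship this pre-filled
heads = torch.zeros(4, 6, dtype=torch.bool)
heads[2, 3] = heads[3, 1] = True
sparse = heads.to_sparse()
print(sparse.indices().T.numpy())  # [[2 3], [3 1]] -> (layer, head) pairs
```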
whisper/test.py (+131 lines)

@@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
         check_segment(result["segments"][5], expected_5)
         check_segment(result["segments"][73], expected_73)

+    def test_transcribe_word_level_timestamps_confidence_scores(self):
+        result = whisper.transcribe(
+            # TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
+            TEST_AUDIO,
+            model_path=MLX_FP16_MODEL_PATH,
+            word_timestamps=True,
+        )
+
+        # result predicted with openai-whisper
+        expected_0 = [
+            {
+                "word": " Then",
+                "start": 0.0,
+                "end": 0.94,
+                "probability": 0.855542778968811,
+            },
+            {
+                "word": " the",
+                "start": 0.94,
+                "end": 1.12,
+                "probability": 0.6500106453895569,
+            },
+            {
+                "word": " good",
+                "start": 1.12,
+                "end": 1.32,
+                "probability": 0.5503873825073242,
+            },
+            {
+                "word": " soul",
+                "start": 1.32,
+                "end": 1.56,
+                "probability": 0.46757155656814575,
+            },
+            {
+                "word": " openly",
+                "start": 1.56,
+                "end": 2.0,
+                "probability": 0.9840946793556213,
+            },
+            {
+                "word": " sorted",
+                "start": 2.0,
+                "end": 2.38,
+                "probability": 0.24167272448539734,
+            },
+            {
+                "word": " the",
+                "start": 2.38,
+                "end": 2.58,
+                "probability": 0.9875414967536926,
+            },
+            {
+                "word": " boat",
+                "start": 2.58,
+                "end": 2.8,
+                "probability": 0.5856029391288757,
+            },
+            {
+                "word": " and",
+                "start": 2.8,
+                "end": 2.98,
+                "probability": 0.913351833820343,
+            },
+            {
+                "word": " she",
+                "start": 2.98,
+                "end": 3.1,
+                "probability": 0.9913808703422546,
+            },
+            {
+                "word": " had",
+                "start": 3.1,
+                "end": 3.32,
+                "probability": 0.9952940344810486,
+            },
+            {
+                "word": " buoyed",
+                "start": 3.32,
+                "end": 3.58,
+                "probability": 0.6411589980125427,
+            },
+            {
+                "word": " so",
+                "start": 3.58,
+                "end": 3.8,
+                "probability": 0.9682658314704895,
+            },
+            {
+                "word": " long",
+                "start": 3.8,
+                "end": 4.06,
+                "probability": 0.9953522682189941,
+            },
+            {
+                "word": " in",
+                "start": 4.06,
+                "end": 4.26,
+                "probability": 0.6745936870574951,
+            },
+            {
+                "word": " secret",
+                "start": 4.26,
+                "end": 4.56,
+                "probability": 0.9905064702033997,
+            },
+            {
+                "word": " and",
+                "start": 4.56,
+                "end": 4.9,
+                "probability": 0.856008768081665,
+            },
+            {
+                "word": " bravely",
+                "start": 4.9,
+                "end": 5.28,
+                "probability": 0.8477402329444885,
+            },
+        ]
+
+        def check_words(words, expected_words):
+            for word, expected_word in zip(words, expected_words):
+                for k, v in expected_word.items():
+                    if isinstance(v, float):
+                        self.assertAlmostEqual(word[k], v, places=1)
+                    else:
+                        self.assertEqual(word[k], v)
+
+        # Randomly check a couple of segments
+        check_words(result["segments"][0]["words"], expected_0)
+
+
 class TestAudio(unittest.TestCase):
     def test_load(self):
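A minimal sketch of running just this new test case with the standard `unittest` API (it assumes you are in the `whisper/` directory and that the converted models referenced by `MLX_FP16_MODEL_PATH` already exist under `mlx_models/`):

```python
import unittest

from test import TestWhisper  # the test module in this directory

suite = unittest.TestSuite()
suite.addTest(TestWhisper("test_transcribe_word_level_timestamps_confidence_scores"))
unittest.TextTestRunner(verbosity=2).run(suite)
```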
@@ -141,7 +141,7 @@ class Inference:
         # only need to use the last token except in the first forward pass
         tokens = tokens[:, -1:]

-        logits, self.kv_cache = self.model.decoder(
+        logits, self.kv_cache, _ = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
         return logits.astype(mx.float32)
@@ -1,15 +1,13 @@
 # Copyright © 2023 Apple Inc.

 import itertools
-import subprocess
-import warnings
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List

+import mlx.core as mx
 import numba
 import numpy as np
-import torch
-import torch.nn.functional as F
+from scipy import signal

 from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
 from .tokenizer import Tokenizer

@@ -18,7 +16,7 @@ if TYPE_CHECKING:
     from .model import Whisper


-def median_filter(x: torch.Tensor, filter_width: int):
+def median_filter(x: np.ndarray, filter_width: int):
     """Apply a median filter of width `filter_width` along the last dimension of `x`"""
     pad_width = filter_width // 2
     if x.shape[-1] <= pad_width:

@@ -33,22 +31,12 @@ def median_filter(x: np.ndarray, filter_width: int):
         filter_width > 0 and filter_width % 2 == 1
     ), "`filter_width` should be an odd number"

-    result = None
-    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
-    if x.is_cuda:
-        try:
-            from .triton_ops import median_filter_cuda
-
-            result = median_filter_cuda(x, filter_width)
-        except (RuntimeError, subprocess.CalledProcessError):
-            warnings.warn(
-                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
-                "falling back to a slower median kernel implementation..."
-            )
-
-    if result is None:
-        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
-        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
+    x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
+
+    # todo: more efficient version in mlx
+    result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[
+        ..., pad_width:-pad_width
+    ]

     if ndim <= 2:
         result = result[0, 0]
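A small, self-contained sketch of the reflect-pad + `scipy.signal.medfilt` pattern the rewritten `median_filter` relies on (the array shape and filter width here are illustrative, not taken from the code):

```python
import numpy as np
from scipy import signal

weights = np.random.rand(2, 4, 50).astype(np.float32)  # (heads, tokens, frames)
pad = 7 // 2
padded = np.pad(weights, ((0, 0), (0, 0), (pad, pad)), mode="reflect")
# kernel_size=(1, 1, 7) applies the median only along the last (frame) axis;
# cropping `pad` columns off each end restores the original length
smoothed = signal.medfilt(padded, kernel_size=(1, 1, 7))[..., pad:-pad]
assert smoothed.shape == weights.shape
```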
@@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray):
     return backtrace(trace)


-def dtw_cuda(x, BLOCK_SIZE=1024):
-    from .triton_ops import dtw_kernel
-
-    M, N = x.shape
-    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
-
-    x_skew = (
-        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
-    )
-    x_skew = x_skew.T.contiguous()
-    cost = torch.ones(N + M + 2, M + 2) * np.inf
-    cost[0, 0] = 0
-    cost = cost.cuda()
-    trace = torch.zeros_like(cost, dtype=torch.int32)
-
-    dtw_kernel[(1,)](
-        cost,
-        trace,
-        x_skew,
-        x_skew.stride(0),
-        cost.stride(0),
-        trace.stride(0),
-        N,
-        M,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-
-    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
-        :, : N + 1
-    ]
-    return backtrace(trace.cpu().numpy())
-
-
-def dtw(x: torch.Tensor) -> np.ndarray:
-    if x.is_cuda:
-        try:
-            return dtw_cuda(x)
-        except (RuntimeError, subprocess.CalledProcessError):
-            warnings.warn(
-                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
-                "falling back to a slower DTW implementation..."
-            )
-
-    return dtw_cpu(x.double().cpu().numpy())
+def dtw(x: np.ndarray) -> np.ndarray:
+    # todo: more efficient version in mlx
+    return dtw_cpu(x)


 @dataclass
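With the CUDA/Triton path removed, everything funnels into the numba-accelerated `dtw_cpu`, which this commit leaves untouched. As a rough, hedged sketch of the dynamic-time-warping step it performs (a generic textbook version, not the repository's implementation): given a token-by-frame cost matrix, accumulate costs over monotonic moves and backtrace the cheapest path.

```python
import numpy as np

def dtw_sketch(cost_matrix: np.ndarray) -> np.ndarray:
    """Toy DTW: returns (2, path_len) indices aligning rows (tokens) to columns (frames)."""
    N, M = cost_matrix.shape
    cost = np.full((N + 1, M + 1), np.inf)
    cost[0, 0] = 0.0
    trace = np.zeros((N + 1, M + 1), dtype=np.int32)
    for i in range(1, N + 1):
        for j in range(1, M + 1):
            steps = (cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1])
            k = int(np.argmin(steps))
            cost[i, j] = cost_matrix[i - 1, j - 1] + steps[k]
            trace[i, j] = k
    # backtrace from the bottom-right corner to recover the alignment path
    i, j, path = N, M, []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        k = trace[i, j]
        if k == 0:
            i, j = i - 1, j - 1
        elif k == 1:
            i -= 1
        else:
            j -= 1
    return np.array(path[::-1]).T
```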
@@ -166,7 +113,7 @@ def find_alignment(
     model: "Whisper",
     tokenizer: Tokenizer,
     text_tokens: List[int],
-    mel: torch.Tensor,
+    mel: mx.array,
     num_frames: int,
     *,
     medfilt_width: int = 7,

@@ -175,41 +122,36 @@ def find_alignment(
     if len(text_tokens) == 0:
         return []

-    tokens = torch.tensor(
+    tokens = mx.array(
         [
             *tokenizer.sot_sequence,
             tokenizer.no_timestamps,
             *text_tokens,
             tokenizer.eot,
         ]
-    ).to(model.device)
+    )

-    # install hooks on the cross attention layers to retrieve the attention weights
-    QKs = [None] * model.dims.n_text_layer
-    hooks = [
-        block.cross_attn.register_forward_hook(
-            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
-        )
-        for i, block in enumerate(model.decoder.blocks)
-    ]
-
-    with torch.no_grad():
-        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
-        token_probs = sampled_logits.softmax(dim=-1)
-        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
-        text_token_probs = text_token_probs.tolist()
-
-    for hook in hooks:
-        hook.remove()
+    logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
+    # consider only the logits associated with predicting text
+    sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
+    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
+        sampled_logits.dtype
+    )
+    text_token_probs = mx.take_along_axis(
+        token_probs, mx.array(text_tokens)[:, None], axis=1
+    ).squeeze(1)
+    text_token_probs = np.array(text_token_probs)

     # heads * tokens * frames
-    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
+    weights = mx.stack(
+        [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+    )
     weights = weights[:, :, : num_frames // 2]
-    weights = (weights * qk_scale).softmax(dim=-1)
-    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+    weights = mx.softmax(weights * qk_scale, axis=-1)
+    mean = mx.mean(weights, axis=-2, keepdims=True)
+    std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std
-    weights = median_filter(weights, medfilt_width)
+    weights = median_filter(np.array(weights), medfilt_width)

     matrix = weights.mean(axis=0)
     matrix = matrix[len(tokenizer.sot_sequence) : -1]
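The per-token confidence lookup above replaces the old fancy indexing on a torch tensor with `mx.take_along_axis`; a tiny sketch with a toy vocabulary (all values invented):

```python
import mlx.core as mx

token_probs = mx.softmax(mx.random.normal((3, 5)), axis=-1)  # 3 text tokens, 5-token vocab
text_tokens = [4, 0, 2]
# probability the model assigned to each emitted text token
probs = mx.take_along_axis(
    token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
```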
@@ -281,7 +223,7 @@ def add_word_timestamps(
     segments: List[dict],
     model: "Whisper",
     tokenizer: Tokenizer,
-    mel: torch.Tensor,
+    mel: mx.array,
     num_frames: int,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",

@@ -301,6 +243,7 @@ def add_word_timestamps(
     word_durations = np.array([t.end - t.start for t in alignment])
     word_durations = word_durations[word_durations.nonzero()]
     median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    median_duration = min(0.7, float(median_duration))
     max_duration = median_duration * 2

     # hack: truncate long words at sentence boundaries.

@@ -1,7 +1,8 @@
 # Copyright © 2023 Apple Inc.

 import sys
-from typing import Optional, Tuple, Union
+import warnings
+from typing import List, Optional, Tuple, Union

 import mlx.core as mx
 import numpy as np

@@ -18,7 +19,8 @@ from .audio import (
 )
 from .decoding import DecodingOptions, DecodingResult
 from .load_models import load_model
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
+from .timing import add_word_timestamps
+from .tokenizer import LANGUAGES, get_tokenizer


 def _format_timestamp(seconds: float):

@@ -38,6 +40,13 @@ def _format_timestamp(seconds: float):
     return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"


+def _get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
+
+
 class ModelHolder:
     model = None
     model_path = None

@@ -61,8 +70,11 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
+    word_timestamps: bool = False,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    clip_timestamps: Union[str, List[float]] = "0",
+    hallucination_silence_threshold: Optional[float] = None,
     **decode_options,
 ):
     """

@@ -99,6 +111,16 @@ def transcribe(
         disabling may make the text inconsistent across windows, but the model becomes less prone to
         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

+    word_timestamps: bool
+        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+        and include the timestamps for each word in each segment.
+
+    prepend_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the next word
+
+    append_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the previous word
+
     initial_prompt: Optional[str]
         Optional text to provide as a prompt for the first window. This can be used to provide, or
         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns

@@ -107,6 +129,14 @@ def transcribe(
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances

+    clip_timestamps: Union[str, List[float]]
+        Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
+        The last end timestamp defaults to the end of the file.
+
+    hallucination_silence_threshold: Optional[float]
+        When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
+        when a possible hallucination is detected
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and

@@ -119,6 +149,7 @@ def transcribe(
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-2] - N_FRAMES
+    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

     if verbose:
         system_encoding = sys.getdefaultencoding()

@@ -155,6 +186,22 @@ def transcribe(
         task=task,
     )

+    if isinstance(clip_timestamps, str):
+        clip_timestamps = [
+            float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+        ]
+    seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+    if len(seek_points) == 0:
+        seek_points.append(0)
+    if len(seek_points) % 2 == 1:
+        seek_points.append(content_frames)
+    seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
+
+    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
+    if word_timestamps and task == "translate":
+        warnings.warn("Word-level timestamps on translations may not be reliable.")
+
     def decode_with_fallback(segment: mx.array) -> DecodingResult:
         temperatures = (
             [temperature] if isinstance(temperature, (int, float)) else temperature

@@ -195,7 +242,8 @@ def transcribe(

         return decode_result

-    seek = 0
+    clip_idx = 0
+    seek = seek_clips[clip_idx][0]
     input_stride = N_FRAMES // model.dims.n_audio_ctx  # mel frames per output token: 2
     time_precision = (
         input_stride * HOP_LENGTH / SAMPLE_RATE

@@ -232,10 +280,23 @@ def transcribe(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
         last_speech_timestamp = 0.0
-        while seek < content_frames:
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[seek : seek + N_FRAMES]
-            segment_size = min(N_FRAMES, content_frames - seek)
+            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
+            mel_segment = mel[seek : seek + segment_size]
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)

@@ -260,6 +321,30 @@ def transcribe(
             previous_seek = seek
             current_segments = []

+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [w for w in segment["words"] if w["word"] not in punctuation]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             timestamp_tokens = tokens >= tokenizer.timestamp_begin
             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

@@ -324,6 +409,83 @@ def transcribe(
                 )
             seek += segment_size

+            if word_timestamps:
+                add_word_timestamps(
+                    segments=current_segments,
+                    model=model,
+                    tokenizer=tokenizer,
+                    mel=mel_segment,
+                    num_frames=segment_size,
+                    prepend_punctuations=prepend_punctuations,
+                    append_punctuations=append_punctuations,
+                    last_speech_timestamp=last_speech_timestamp,
+                )
+
+                if not single_timestamp_ending:
+                    last_word_end = _get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                # skip silence before possible hallucinations
+                if hallucination_silence_threshold is not None:
+                    threshold = hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = _get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * FRAMES_PER_SECOND)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1 :]
+                            )
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0]["start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold
+                                or segment["start"] < threshold
+                                or segment["start"] - time_offset < 2.0
+                            )
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold
+                                or is_segment_anomaly(next_segment)
+                                or window_end_time - segment["end"] < 2.0
+                            )
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"])
+                                    * FRAMES_PER_SECOND
+                                )
+                                if content_duration - segment["end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = _get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end
+
             if verbose:
                 for segment in current_segments:
                     start, end, text = segment["start"], segment["end"], segment["text"]
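An illustrative call that exercises the options added to `transcribe` above (the audio path and threshold values are hypothetical); `clip_timestamps` is parsed into `(start_frame, end_frame)` clips and the word dictionaries land under each segment:

```python
import whisper

result = whisper.transcribe(
    "speech.mp3",                         # hypothetical input file
    word_timestamps=True,
    clip_timestamps="0,30",               # only decode the first 30 seconds
    hallucination_silence_threshold=2.0,  # skip >2 s silences around suspect segments
)
for word in result["segments"][0]["words"]:
    print(word["word"], word["start"], word["end"], word["probability"])
```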
@@ -4,7 +4,7 @@ import base64
 import gzip
 import math
 from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
+from typing import Union

 import mlx.core as mx
 import mlx.nn as nn

@@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module):
         else:
             k, v = kv_cache

-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv), (k, v)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), (k, v), qk

     def qkv_attention(self, q, k, v, mask=None):
         n_batch, n_ctx, n_state = q.shape

@@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]

         qk = qk.astype(mx.float32)

         w = mx.softmax(qk, axis=-1).astype(q.dtype)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out
+        return out, qk


 class ResidualAttentionBlock(nn.Module):

@@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module):

     def __call__(self, x, xa=None, mask=None, kv_cache=None):
         kv, cross_kv = kv_cache if kv_cache else (None, None)
-        y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
+        y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
         x += y
+        cross_qk = None
         if self.cross_attn:
-            y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
+            y, cross_kv, cross_qk = self.cross_attn(
+                self.cross_attn_ln(x), xa, kv_cache=cross_kv
+            )
             x += y
         x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
-        return x, (kv, cross_kv)
+        return x, (kv, cross_kv), cross_qk


 class AudioEncoder(nn.Module):

@@ -146,7 +149,7 @@ class AudioEncoder(nn.Module):
         x = x + self._positional_embedding

         for block in self.blocks:
-            x, _ = block(x)
+            x, _, _ = block(x)

         x = self.ln_post(x)
         return x

@@ -191,11 +194,14 @@ class TextDecoder(nn.Module):

         if kv_cache is None:
             kv_cache = [None] * len(self.blocks)
+        cross_qk = [None] * len(self.blocks)
         for e, block in enumerate(self.blocks):
-            x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])
+            x, kv_cache[e], cross_qk[e] = block(
+                x, xa, mask=self._mask, kv_cache=kv_cache[e]
+            )

         x = self.ln(x)
-        return x @ self.token_embedding.weight.T, kv_cache
+        return x @ self.token_embedding.weight.T, kv_cache, cross_qk


 class Whisper(nn.Module):

@@ -218,6 +224,28 @@ class Whisper(nn.Module):
             self.dims.n_text_layer,
             dtype,
         )
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
+        all_heads = np.zeros(
+            (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
+        self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
+
+    def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
+        if isinstance(dump, np.ndarray):
+            self.alignment_heads = mx.array(dump)
+        elif isinstance(dump, bytes):
+            array = np.frombuffer(
+                gzip.decompress(base64.b85decode(dump)), dtype=bool
+            ).copy()
+            mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
+            self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
+        else:
+            raise ValueError(
+                f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
+                " alignment_head information"
+            )

     def embed_audio(self, mel):
         return self.encoder(mel)
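A hedged usage sketch for the new `set_alignment_heads`: besides the gzip/base85 byte dumps, it accepts a plain array of (layer, head) index pairs, which is what `convert.py` passes. Here `model` stands for an already loaded `Whisper` instance, and the 4-layer/6-head shape is an assumption (roughly the `tiny` text decoder), not something stated in the diff:

```python
import numpy as np

# pick two hypothetical cross-attention heads to use for alignment
mask = np.zeros((4, 6), dtype=bool)  # (n_text_layer, n_text_head)
mask[2, 3] = True
mask[3, 1] = True
model.set_alignment_heads(np.asarray(mask.nonzero()).T)  # [[2, 3], [3, 1]]
```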
@@ -225,6 +253,10 @@ class Whisper(nn.Module):
     def logits(self, tokens, audio_features):
         return self.decoder(tokens, audio_features)[0]

+    def forward_with_cross_qk(self, mel, tokens):
+        logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
+        return logits, cross_qk
+
     def __call__(self, mel, tokens):
         return self.decoder(tokens, self.encoder(mel))[0]