mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-23 14:08:07 +08:00 
			
		
		
		
	[Whisper] Add word timestamps and confidence scores (#201)
* Add word timestamps and confidence scores * Create a separate forward_with_cross_qk function * Move multiple ops from np to mlx, clean comments * Save alignment_heads * Cast qk to fp32 * Add test for word-level timestamps and confidence scores * format + readme * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -2,7 +2,7 @@ | ||||
|  | ||||
| Speech recognition with Whisper in MLX. Whisper is a set of open source speech | ||||
| recognition models from OpenAI, ranging from 39 million to 1.5 billion | ||||
| parameters[^1]. | ||||
| parameters.[^1] | ||||
|  | ||||
| ### Setup | ||||
|  | ||||
| @@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/): | ||||
| brew install ffmpeg | ||||
| ``` | ||||
|  | ||||
| Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use: | ||||
| Next, download the Whisper PyTorch checkpoint and convert the weights to the | ||||
| MLX format. For example, to convert the `tiny` model use: | ||||
|  | ||||
| ``` | ||||
| python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny | ||||
| @@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there. | ||||
|  | ||||
| Transcribe audio with: | ||||
|  | ||||
| ``` | ||||
| ```python | ||||
| import whisper | ||||
|  | ||||
| text = whisper.transcribe(speech_file)["text"] | ||||
| ``` | ||||
|  | ||||
| The `transcribe` function also supports word-level timestamps. You can generate | ||||
| these with: | ||||
|  | ||||
| ```python | ||||
| output = whisper.transcribe(speech_file, word_timestamps=True) | ||||
| print(output["segments"][0]["words"]) | ||||
| ``` | ||||
|  | ||||
| To see more transcription options use: | ||||
|  | ||||
| ``` | ||||
| >>> help(whisper.transcribe) | ||||
| ``` | ||||
|  | ||||
| [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details. | ||||
|   | ||||
| @@ -199,6 +199,10 @@ def torch_to_mlx( | ||||
|     mlx_model = Whisper(torch_model.dims, dtype) | ||||
|     params = tree_map(lambda p: p.astype(dtype), params) | ||||
|     mlx_model.update(params) | ||||
|  | ||||
|     if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None: | ||||
|         mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy()) | ||||
|  | ||||
|     return mlx_model | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										131
									
								
								whisper/test.py
									
									
									
									
									
								
							
							
						
						
									
										131
									
								
								whisper/test.py
									
									
									
									
									
								
							| @@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase): | ||||
|         check_segment(result["segments"][5], expected_5) | ||||
|         check_segment(result["segments"][73], expected_73) | ||||
|  | ||||
|     def test_transcribe_word_level_timestamps_confidence_scores(self): | ||||
|         result = whisper.transcribe( | ||||
|             # TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False | ||||
|             TEST_AUDIO, | ||||
|             model_path=MLX_FP16_MODEL_PATH, | ||||
|             word_timestamps=True, | ||||
|         ) | ||||
|  | ||||
|         # result predicted with openai-whisper | ||||
|         expected_0 = [ | ||||
|             { | ||||
|                 "word": " Then", | ||||
|                 "start": 0.0, | ||||
|                 "end": 0.94, | ||||
|                 "probability": 0.855542778968811, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " the", | ||||
|                 "start": 0.94, | ||||
|                 "end": 1.12, | ||||
|                 "probability": 0.6500106453895569, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " good", | ||||
|                 "start": 1.12, | ||||
|                 "end": 1.32, | ||||
|                 "probability": 0.5503873825073242, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " soul", | ||||
|                 "start": 1.32, | ||||
|                 "end": 1.56, | ||||
|                 "probability": 0.46757155656814575, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " openly", | ||||
|                 "start": 1.56, | ||||
|                 "end": 2.0, | ||||
|                 "probability": 0.9840946793556213, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " sorted", | ||||
|                 "start": 2.0, | ||||
|                 "end": 2.38, | ||||
|                 "probability": 0.24167272448539734, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " the", | ||||
|                 "start": 2.38, | ||||
|                 "end": 2.58, | ||||
|                 "probability": 0.9875414967536926, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " boat", | ||||
|                 "start": 2.58, | ||||
|                 "end": 2.8, | ||||
|                 "probability": 0.5856029391288757, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " and", | ||||
|                 "start": 2.8, | ||||
|                 "end": 2.98, | ||||
|                 "probability": 0.913351833820343, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " she", | ||||
|                 "start": 2.98, | ||||
|                 "end": 3.1, | ||||
|                 "probability": 0.9913808703422546, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " had", | ||||
|                 "start": 3.1, | ||||
|                 "end": 3.32, | ||||
|                 "probability": 0.9952940344810486, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " buoyed", | ||||
|                 "start": 3.32, | ||||
|                 "end": 3.58, | ||||
|                 "probability": 0.6411589980125427, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " so", | ||||
|                 "start": 3.58, | ||||
|                 "end": 3.8, | ||||
|                 "probability": 0.9682658314704895, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " long", | ||||
|                 "start": 3.8, | ||||
|                 "end": 4.06, | ||||
|                 "probability": 0.9953522682189941, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " in", | ||||
|                 "start": 4.06, | ||||
|                 "end": 4.26, | ||||
|                 "probability": 0.6745936870574951, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " secret", | ||||
|                 "start": 4.26, | ||||
|                 "end": 4.56, | ||||
|                 "probability": 0.9905064702033997, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " and", | ||||
|                 "start": 4.56, | ||||
|                 "end": 4.9, | ||||
|                 "probability": 0.856008768081665, | ||||
|             }, | ||||
|             { | ||||
|                 "word": " bravely", | ||||
|                 "start": 4.9, | ||||
|                 "end": 5.28, | ||||
|                 "probability": 0.8477402329444885, | ||||
|             }, | ||||
|         ] | ||||
|  | ||||
|         def check_words(words, expected_words): | ||||
|             for word, expected_word in zip(words, expected_words): | ||||
|                 for k, v in expected_word.items(): | ||||
|                     if isinstance(v, float): | ||||
|                         self.assertAlmostEqual(word[k], v, places=1) | ||||
|                     else: | ||||
|                         self.assertEqual(word[k], v) | ||||
|  | ||||
|         # Randomly check a couple of segments | ||||
|         check_words(result["segments"][0]["words"], expected_0) | ||||
|  | ||||
|  | ||||
| class TestAudio(unittest.TestCase): | ||||
|     def test_load(self): | ||||
|   | ||||
| @@ -141,7 +141,7 @@ class Inference: | ||||
|             # only need to use the last token except in the first forward pass | ||||
|             tokens = tokens[:, -1:] | ||||
|  | ||||
|         logits, self.kv_cache = self.model.decoder( | ||||
|         logits, self.kv_cache, _ = self.model.decoder( | ||||
|             tokens, audio_features, kv_cache=self.kv_cache | ||||
|         ) | ||||
|         return logits.astype(mx.float32) | ||||
|   | ||||
| @@ -1,15 +1,13 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import itertools | ||||
| import subprocess | ||||
| import warnings | ||||
| from dataclasses import dataclass | ||||
| from typing import TYPE_CHECKING, List | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numba | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from scipy import signal | ||||
|  | ||||
| from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND | ||||
| from .tokenizer import Tokenizer | ||||
| @@ -18,7 +16,7 @@ if TYPE_CHECKING: | ||||
|     from .model import Whisper | ||||
|  | ||||
|  | ||||
| def median_filter(x: torch.Tensor, filter_width: int): | ||||
| def median_filter(x: np.ndarray, filter_width: int): | ||||
|     """Apply a median filter of width `filter_width` along the last dimension of `x`""" | ||||
|     pad_width = filter_width // 2 | ||||
|     if x.shape[-1] <= pad_width: | ||||
| @@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int): | ||||
|         filter_width > 0 and filter_width % 2 == 1 | ||||
|     ), "`filter_width` should be an odd number" | ||||
|  | ||||
|     result = None | ||||
|     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") | ||||
|     if x.is_cuda: | ||||
|         try: | ||||
|             from .triton_ops import median_filter_cuda | ||||
|     x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect") | ||||
|  | ||||
|             result = median_filter_cuda(x, filter_width) | ||||
|         except (RuntimeError, subprocess.CalledProcessError): | ||||
|             warnings.warn( | ||||
|                 "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " | ||||
|                 "falling back to a slower median kernel implementation..." | ||||
|             ) | ||||
|  | ||||
|     if result is None: | ||||
|         # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) | ||||
|         result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] | ||||
|     # todo: more efficient version in mlx | ||||
|     result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[ | ||||
|         ..., pad_width:-pad_width | ||||
|     ] | ||||
|  | ||||
|     if ndim <= 2: | ||||
|         result = result[0, 0] | ||||
| @@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray): | ||||
|     return backtrace(trace) | ||||
|  | ||||
|  | ||||
| def dtw_cuda(x, BLOCK_SIZE=1024): | ||||
|     from .triton_ops import dtw_kernel | ||||
|  | ||||
|     M, N = x.shape | ||||
|     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" | ||||
|  | ||||
|     x_skew = ( | ||||
|         F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) | ||||
|     ) | ||||
|     x_skew = x_skew.T.contiguous() | ||||
|     cost = torch.ones(N + M + 2, M + 2) * np.inf | ||||
|     cost[0, 0] = 0 | ||||
|     cost = cost.cuda() | ||||
|     trace = torch.zeros_like(cost, dtype=torch.int32) | ||||
|  | ||||
|     dtw_kernel[(1,)]( | ||||
|         cost, | ||||
|         trace, | ||||
|         x_skew, | ||||
|         x_skew.stride(0), | ||||
|         cost.stride(0), | ||||
|         trace.stride(0), | ||||
|         N, | ||||
|         M, | ||||
|         BLOCK_SIZE=BLOCK_SIZE, | ||||
|     ) | ||||
|  | ||||
|     trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ | ||||
|         :, : N + 1 | ||||
|     ] | ||||
|     return backtrace(trace.cpu().numpy()) | ||||
|  | ||||
|  | ||||
| def dtw(x: torch.Tensor) -> np.ndarray: | ||||
|     if x.is_cuda: | ||||
|         try: | ||||
|             return dtw_cuda(x) | ||||
|         except (RuntimeError, subprocess.CalledProcessError): | ||||
|             warnings.warn( | ||||
|                 "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " | ||||
|                 "falling back to a slower DTW implementation..." | ||||
|             ) | ||||
|  | ||||
|     return dtw_cpu(x.double().cpu().numpy()) | ||||
| def dtw(x: np.ndarray) -> np.ndarray: | ||||
|     # todo: more efficient version in mlx | ||||
|     return dtw_cpu(x) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| @@ -166,7 +113,7 @@ def find_alignment( | ||||
|     model: "Whisper", | ||||
|     tokenizer: Tokenizer, | ||||
|     text_tokens: List[int], | ||||
|     mel: torch.Tensor, | ||||
|     mel: mx.array, | ||||
|     num_frames: int, | ||||
|     *, | ||||
|     medfilt_width: int = 7, | ||||
| @@ -175,41 +122,36 @@ def find_alignment( | ||||
|     if len(text_tokens) == 0: | ||||
|         return [] | ||||
|  | ||||
|     tokens = torch.tensor( | ||||
|     tokens = mx.array( | ||||
|         [ | ||||
|             *tokenizer.sot_sequence, | ||||
|             tokenizer.no_timestamps, | ||||
|             *text_tokens, | ||||
|             tokenizer.eot, | ||||
|         ] | ||||
|     ).to(model.device) | ||||
|     ) | ||||
|  | ||||
|     # install hooks on the cross attention layers to retrieve the attention weights | ||||
|     QKs = [None] * model.dims.n_text_layer | ||||
|     hooks = [ | ||||
|         block.cross_attn.register_forward_hook( | ||||
|             lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) | ||||
|         ) | ||||
|         for i, block in enumerate(model.decoder.blocks) | ||||
|     ] | ||||
|  | ||||
|     with torch.no_grad(): | ||||
|         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] | ||||
|         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] | ||||
|         token_probs = sampled_logits.softmax(dim=-1) | ||||
|         text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] | ||||
|         text_token_probs = text_token_probs.tolist() | ||||
|  | ||||
|     for hook in hooks: | ||||
|         hook.remove() | ||||
|     logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :]) | ||||
|     # consider only the logits associated with predicting text | ||||
|     sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot] | ||||
|     token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype( | ||||
|         sampled_logits.dtype | ||||
|     ) | ||||
|     text_token_probs = mx.take_along_axis( | ||||
|         token_probs, mx.array(text_tokens)[:, None], axis=1 | ||||
|     ).squeeze(1) | ||||
|     text_token_probs = np.array(text_token_probs) | ||||
|  | ||||
|     # heads * tokens * frames | ||||
|     weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) | ||||
|     weights = mx.stack( | ||||
|         [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads] | ||||
|     ) | ||||
|     weights = weights[:, :, : num_frames // 2] | ||||
|     weights = (weights * qk_scale).softmax(dim=-1) | ||||
|     std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) | ||||
|     weights = mx.softmax(weights * qk_scale, axis=-1) | ||||
|     mean = mx.mean(weights, axis=-2, keepdims=True) | ||||
|     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt() | ||||
|     weights = (weights - mean) / std | ||||
|     weights = median_filter(weights, medfilt_width) | ||||
|     weights = median_filter(np.array(weights), medfilt_width) | ||||
|  | ||||
|     matrix = weights.mean(axis=0) | ||||
|     matrix = matrix[len(tokenizer.sot_sequence) : -1] | ||||
| @@ -281,7 +223,7 @@ def add_word_timestamps( | ||||
|     segments: List[dict], | ||||
|     model: "Whisper", | ||||
|     tokenizer: Tokenizer, | ||||
|     mel: torch.Tensor, | ||||
|     mel: mx.array, | ||||
|     num_frames: int, | ||||
|     prepend_punctuations: str = "\"'“¿([{-", | ||||
|     append_punctuations: str = "\"'.。,,!!??::”)]}、", | ||||
| @@ -301,6 +243,7 @@ def add_word_timestamps( | ||||
|     word_durations = np.array([t.end - t.start for t in alignment]) | ||||
|     word_durations = word_durations[word_durations.nonzero()] | ||||
|     median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 | ||||
|     median_duration = min(0.7, float(median_duration)) | ||||
|     max_duration = median_duration * 2 | ||||
|  | ||||
|     # hack: truncate long words at sentence boundaries. | ||||
|   | ||||
| @@ -1,7 +1,8 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import sys | ||||
| from typing import Optional, Tuple, Union | ||||
| import warnings | ||||
| from typing import List, Optional, Tuple, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| @@ -18,7 +19,8 @@ from .audio import ( | ||||
| ) | ||||
| from .decoding import DecodingOptions, DecodingResult | ||||
| from .load_models import load_model | ||||
| from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer | ||||
| from .timing import add_word_timestamps | ||||
| from .tokenizer import LANGUAGES, get_tokenizer | ||||
|  | ||||
|  | ||||
| def _format_timestamp(seconds: float): | ||||
| @@ -38,6 +40,13 @@ def _format_timestamp(seconds: float): | ||||
|     return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}" | ||||
|  | ||||
|  | ||||
| def _get_end(segments: List[dict]) -> Optional[float]: | ||||
|     return next( | ||||
|         (w["end"] for s in reversed(segments) for w in reversed(s["words"])), | ||||
|         segments[-1]["end"] if segments else None, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class ModelHolder: | ||||
|     model = None | ||||
|     model_path = None | ||||
| @@ -61,8 +70,11 @@ def transcribe( | ||||
|     no_speech_threshold: Optional[float] = 0.6, | ||||
|     condition_on_previous_text: bool = True, | ||||
|     initial_prompt: Optional[str] = None, | ||||
|     word_timestamps: bool = False, | ||||
|     prepend_punctuations: str = "\"'“¿([{-", | ||||
|     append_punctuations: str = "\"'.。,,!!??::”)]}、", | ||||
|     clip_timestamps: Union[str, List[float]] = "0", | ||||
|     hallucination_silence_threshold: Optional[float] = None, | ||||
|     **decode_options, | ||||
| ): | ||||
|     """ | ||||
| @@ -99,6 +111,16 @@ def transcribe( | ||||
|         disabling may make the text inconsistent across windows, but the model becomes less prone to | ||||
|         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. | ||||
|  | ||||
|     word_timestamps: bool | ||||
|         Extract word-level timestamps using the cross-attention pattern and dynamic time warping, | ||||
|         and include the timestamps for each word in each segment. | ||||
|  | ||||
|     prepend_punctuations: str | ||||
|         If word_timestamps is True, merge these punctuation symbols with the next word | ||||
|  | ||||
|     append_punctuations: str | ||||
|         If word_timestamps is True, merge these punctuation symbols with the previous word | ||||
|  | ||||
|     initial_prompt: Optional[str] | ||||
|         Optional text to provide as a prompt for the first window. This can be used to provide, or | ||||
|         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns | ||||
| @@ -107,6 +129,14 @@ def transcribe( | ||||
|     decode_options: dict | ||||
|         Keyword arguments to construct `DecodingOptions` instances | ||||
|  | ||||
|     clip_timestamps: Union[str, List[float]] | ||||
|         Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. | ||||
|         The last end timestamp defaults to the end of the file. | ||||
|  | ||||
|     hallucination_silence_threshold: Optional[float] | ||||
|         When word_timestamps is True, skip silent periods longer than this threshold (in seconds) | ||||
|         when a possible hallucination is detected | ||||
|  | ||||
|     Returns | ||||
|     ------- | ||||
|     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and | ||||
| @@ -119,6 +149,7 @@ def transcribe( | ||||
|     # Pad 30-seconds of silence to the input audio, for slicing | ||||
|     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) | ||||
|     content_frames = mel.shape[-2] - N_FRAMES | ||||
|     content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) | ||||
|  | ||||
|     if verbose: | ||||
|         system_encoding = sys.getdefaultencoding() | ||||
| @@ -155,6 +186,22 @@ def transcribe( | ||||
|         task=task, | ||||
|     ) | ||||
|  | ||||
|     if isinstance(clip_timestamps, str): | ||||
|         clip_timestamps = [ | ||||
|             float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) | ||||
|         ] | ||||
|     seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] | ||||
|     if len(seek_points) == 0: | ||||
|         seek_points.append(0) | ||||
|     if len(seek_points) % 2 == 1: | ||||
|         seek_points.append(content_frames) | ||||
|     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) | ||||
|  | ||||
|     punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" | ||||
|  | ||||
|     if word_timestamps and task == "translate": | ||||
|         warnings.warn("Word-level timestamps on translations may not be reliable.") | ||||
|  | ||||
|     def decode_with_fallback(segment: mx.array) -> DecodingResult: | ||||
|         temperatures = ( | ||||
|             [temperature] if isinstance(temperature, (int, float)) else temperature | ||||
| @@ -195,7 +242,8 @@ def transcribe( | ||||
|  | ||||
|         return decode_result | ||||
|  | ||||
|     seek = 0 | ||||
|     clip_idx = 0 | ||||
|     seek = seek_clips[clip_idx][0] | ||||
|     input_stride = N_FRAMES // model.dims.n_audio_ctx  # mel frames per output token: 2 | ||||
|     time_precision = ( | ||||
|         input_stride * HOP_LENGTH / SAMPLE_RATE | ||||
| @@ -232,10 +280,23 @@ def transcribe( | ||||
|         total=content_frames, unit="frames", disable=verbose is not False | ||||
|     ) as pbar: | ||||
|         last_speech_timestamp = 0.0 | ||||
|         while seek < content_frames: | ||||
|         # NOTE: This loop is obscurely flattened to make the diff readable. | ||||
|         # A later commit should turn this into a simpler nested loop. | ||||
|         # for seek_clip_start, seek_clip_end in seek_clips: | ||||
|         #     while seek < seek_clip_end | ||||
|         while clip_idx < len(seek_clips): | ||||
|             seek_clip_start, seek_clip_end = seek_clips[clip_idx] | ||||
|             if seek < seek_clip_start: | ||||
|                 seek = seek_clip_start | ||||
|             if seek >= seek_clip_end: | ||||
|                 clip_idx += 1 | ||||
|                 if clip_idx < len(seek_clips): | ||||
|                     seek = seek_clips[clip_idx][0] | ||||
|                 continue | ||||
|             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) | ||||
|             mel_segment = mel[seek : seek + N_FRAMES] | ||||
|             segment_size = min(N_FRAMES, content_frames - seek) | ||||
|             window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) | ||||
|             segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) | ||||
|             mel_segment = mel[seek : seek + segment_size] | ||||
|             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE | ||||
|             mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype) | ||||
|  | ||||
| @@ -260,6 +321,30 @@ def transcribe( | ||||
|             previous_seek = seek | ||||
|             current_segments = [] | ||||
|  | ||||
|             # anomalous words are very long/short/improbable | ||||
|             def word_anomaly_score(word: dict) -> float: | ||||
|                 probability = word.get("probability", 0.0) | ||||
|                 duration = word["end"] - word["start"] | ||||
|                 score = 0.0 | ||||
|                 if probability < 0.15: | ||||
|                     score += 1.0 | ||||
|                 if duration < 0.133: | ||||
|                     score += (0.133 - duration) * 15 | ||||
|                 if duration > 2.0: | ||||
|                     score += duration - 2.0 | ||||
|                 return score | ||||
|  | ||||
|             def is_segment_anomaly(segment: Optional[dict]) -> bool: | ||||
|                 if segment is None or not segment["words"]: | ||||
|                     return False | ||||
|                 words = [w for w in segment["words"] if w["word"] not in punctuation] | ||||
|                 words = words[:8] | ||||
|                 score = sum(word_anomaly_score(w) for w in words) | ||||
|                 return score >= 3 or score + 0.01 >= len(words) | ||||
|  | ||||
|             def next_words_segment(segments: List[dict]) -> Optional[dict]: | ||||
|                 return next((s for s in segments if s["words"]), None) | ||||
|  | ||||
|             timestamp_tokens = tokens >= tokenizer.timestamp_begin | ||||
|             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] | ||||
|  | ||||
| @@ -324,6 +409,83 @@ def transcribe( | ||||
|                 ) | ||||
|                 seek += segment_size | ||||
|  | ||||
|             if word_timestamps: | ||||
|                 add_word_timestamps( | ||||
|                     segments=current_segments, | ||||
|                     model=model, | ||||
|                     tokenizer=tokenizer, | ||||
|                     mel=mel_segment, | ||||
|                     num_frames=segment_size, | ||||
|                     prepend_punctuations=prepend_punctuations, | ||||
|                     append_punctuations=append_punctuations, | ||||
|                     last_speech_timestamp=last_speech_timestamp, | ||||
|                 ) | ||||
|  | ||||
|                 if not single_timestamp_ending: | ||||
|                     last_word_end = _get_end(current_segments) | ||||
|                     if last_word_end is not None and last_word_end > time_offset: | ||||
|                         seek = round(last_word_end * FRAMES_PER_SECOND) | ||||
|  | ||||
|                 # skip silence before possible hallucinations | ||||
|                 if hallucination_silence_threshold is not None: | ||||
|                     threshold = hallucination_silence_threshold | ||||
|                     if not single_timestamp_ending: | ||||
|                         last_word_end = _get_end(current_segments) | ||||
|                         if last_word_end is not None and last_word_end > time_offset: | ||||
|                             remaining_duration = window_end_time - last_word_end | ||||
|                             if remaining_duration > threshold: | ||||
|                                 seek = round(last_word_end * FRAMES_PER_SECOND) | ||||
|                             else: | ||||
|                                 seek = previous_seek + segment_size | ||||
|  | ||||
|                     # if first segment might be a hallucination, skip leading silence | ||||
|                     first_segment = next_words_segment(current_segments) | ||||
|                     if first_segment is not None and is_segment_anomaly(first_segment): | ||||
|                         gap = first_segment["start"] - time_offset | ||||
|                         if gap > threshold: | ||||
|                             seek = previous_seek + round(gap * FRAMES_PER_SECOND) | ||||
|                             continue | ||||
|  | ||||
|                     # skip silence before any possible hallucination that is surrounded | ||||
|                     # by silence or more hallucinations | ||||
|                     hal_last_end = last_speech_timestamp | ||||
|                     for si in range(len(current_segments)): | ||||
|                         segment = current_segments[si] | ||||
|                         if not segment["words"]: | ||||
|                             continue | ||||
|                         if is_segment_anomaly(segment): | ||||
|                             next_segment = next_words_segment( | ||||
|                                 current_segments[si + 1 :] | ||||
|                             ) | ||||
|                             if next_segment is not None: | ||||
|                                 hal_next_start = next_segment["words"][0]["start"] | ||||
|                             else: | ||||
|                                 hal_next_start = time_offset + segment_duration | ||||
|                             silence_before = ( | ||||
|                                 segment["start"] - hal_last_end > threshold | ||||
|                                 or segment["start"] < threshold | ||||
|                                 or segment["start"] - time_offset < 2.0 | ||||
|                             ) | ||||
|                             silence_after = ( | ||||
|                                 hal_next_start - segment["end"] > threshold | ||||
|                                 or is_segment_anomaly(next_segment) | ||||
|                                 or window_end_time - segment["end"] < 2.0 | ||||
|                             ) | ||||
|                             if silence_before and silence_after: | ||||
|                                 seek = round( | ||||
|                                     max(time_offset + 1, segment["start"]) | ||||
|                                     * FRAMES_PER_SECOND | ||||
|                                 ) | ||||
|                                 if content_duration - segment["end"] < threshold: | ||||
|                                     seek = content_frames | ||||
|                                 current_segments[si:] = [] | ||||
|                                 break | ||||
|                         hal_last_end = segment["end"] | ||||
|  | ||||
|                 last_word_end = _get_end(current_segments) | ||||
|                 if last_word_end is not None: | ||||
|                     last_speech_timestamp = last_word_end | ||||
|  | ||||
|             if verbose: | ||||
|                 for segment in current_segments: | ||||
|                     start, end, text = segment["start"], segment["end"], segment["text"] | ||||
|   | ||||
| @@ -4,7 +4,7 @@ import base64 | ||||
| import gzip | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
| from typing import Dict, Iterable, Optional | ||||
| from typing import Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| @@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module): | ||||
|         else: | ||||
|             k, v = kv_cache | ||||
|  | ||||
|         wv = self.qkv_attention(q, k, v, mask) | ||||
|         return self.out(wv), (k, v) | ||||
|         wv, qk = self.qkv_attention(q, k, v, mask) | ||||
|         return self.out(wv), (k, v), qk | ||||
|  | ||||
|     def qkv_attention(self, q, k, v, mask=None): | ||||
|         n_batch, n_ctx, n_state = q.shape | ||||
| @@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module): | ||||
|         qk = q @ k | ||||
|         if mask is not None: | ||||
|             qk = qk + mask[:n_ctx, :n_ctx] | ||||
|  | ||||
|         qk = qk.astype(mx.float32) | ||||
|  | ||||
|         w = mx.softmax(qk, axis=-1).astype(q.dtype) | ||||
|         out = (w @ v).transpose(0, 2, 1, 3) | ||||
|         out = out.reshape(n_batch, n_ctx, n_state) | ||||
|         return out | ||||
|         return out, qk | ||||
|  | ||||
|  | ||||
| class ResidualAttentionBlock(nn.Module): | ||||
| @@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module): | ||||
|  | ||||
|     def __call__(self, x, xa=None, mask=None, kv_cache=None): | ||||
|         kv, cross_kv = kv_cache if kv_cache else (None, None) | ||||
|         y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv) | ||||
|         y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv) | ||||
|         x += y | ||||
|         cross_qk = None | ||||
|         if self.cross_attn: | ||||
|             y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) | ||||
|             y, cross_kv, cross_qk = self.cross_attn( | ||||
|                 self.cross_attn_ln(x), xa, kv_cache=cross_kv | ||||
|             ) | ||||
|             x += y | ||||
|         x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype)) | ||||
|         return x, (kv, cross_kv) | ||||
|         return x, (kv, cross_kv), cross_qk | ||||
|  | ||||
|  | ||||
| class AudioEncoder(nn.Module): | ||||
| @@ -146,7 +149,7 @@ class AudioEncoder(nn.Module): | ||||
|         x = x + self._positional_embedding | ||||
|  | ||||
|         for block in self.blocks: | ||||
|             x, _ = block(x) | ||||
|             x, _, _ = block(x) | ||||
|  | ||||
|         x = self.ln_post(x) | ||||
|         return x | ||||
| @@ -191,11 +194,14 @@ class TextDecoder(nn.Module): | ||||
|  | ||||
|         if kv_cache is None: | ||||
|             kv_cache = [None] * len(self.blocks) | ||||
|         cross_qk = [None] * len(self.blocks) | ||||
|         for e, block in enumerate(self.blocks): | ||||
|             x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e]) | ||||
|             x, kv_cache[e], cross_qk[e] = block( | ||||
|                 x, xa, mask=self._mask, kv_cache=kv_cache[e] | ||||
|             ) | ||||
|  | ||||
|         x = self.ln(x) | ||||
|         return x @ self.token_embedding.weight.T, kv_cache | ||||
|         return x @ self.token_embedding.weight.T, kv_cache, cross_qk | ||||
|  | ||||
|  | ||||
| class Whisper(nn.Module): | ||||
| @@ -218,6 +224,28 @@ class Whisper(nn.Module): | ||||
|             self.dims.n_text_layer, | ||||
|             dtype, | ||||
|         ) | ||||
|         # use the last half among the decoder layers for time alignment by default; | ||||
|         # to use a specific set of heads, see `set_alignment_heads()` below. | ||||
|         all_heads = np.zeros( | ||||
|             (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool | ||||
|         ) | ||||
|         all_heads[self.dims.n_text_layer // 2 :] = True | ||||
|         self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T) | ||||
|  | ||||
|     def set_alignment_heads(self, dump: Union[bytes, np.ndarray]): | ||||
|         if isinstance(dump, np.ndarray): | ||||
|             self.alignment_heads = mx.array(dump) | ||||
|         elif isinstance(dump, bytes): | ||||
|             array = np.frombuffer( | ||||
|                 gzip.decompress(base64.b85decode(dump)), dtype=bool | ||||
|             ).copy() | ||||
|             mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head) | ||||
|             self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing" | ||||
|                 " alignment_head information" | ||||
|             ) | ||||
|  | ||||
|     def embed_audio(self, mel): | ||||
|         return self.encoder(mel) | ||||
| @@ -225,6 +253,10 @@ class Whisper(nn.Module): | ||||
|     def logits(self, tokens, audio_features): | ||||
|         return self.decoder(tokens, audio_features)[0] | ||||
|  | ||||
|     def forward_with_cross_qk(self, mel, tokens): | ||||
|         logits, _, cross_qk = self.decoder(tokens, self.encoder(mel)) | ||||
|         return logits, cross_qk | ||||
|  | ||||
|     def __call__(self, mel, tokens): | ||||
|         return self.decoder(tokens, self.encoder(mel))[0] | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 bofeng huang
					bofeng huang