| 
									
										
										
										
											2023-11-30 11:08:53 -08:00
										 |  |  |  | # Copyright © 2023 Apple Inc. | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | import sys | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  | import warnings | 
					
						
							|  |  |  |  | from typing import List, Optional, Tuple, Union | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-20 10:22:25 -08:00
										 |  |  |  | import mlx.core as mx | 
					
						
							|  |  |  |  | import numpy as np | 
					
						
							|  |  |  |  | import tqdm | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | from .audio import ( | 
					
						
							|  |  |  |  |     FRAMES_PER_SECOND, | 
					
						
							|  |  |  |  |     HOP_LENGTH, | 
					
						
							|  |  |  |  |     N_FRAMES, | 
					
						
							|  |  |  |  |     N_SAMPLES, | 
					
						
							|  |  |  |  |     SAMPLE_RATE, | 
					
						
							|  |  |  |  |     log_mel_spectrogram, | 
					
						
							|  |  |  |  |     pad_or_trim, | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | from .decoding import DecodingOptions, DecodingResult | 
					
						
							|  |  |  |  | from .load_models import load_model | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  | from .timing import add_word_timestamps | 
					
						
							|  |  |  |  | from .tokenizer import LANGUAGES, get_tokenizer | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def _format_timestamp(seconds: float): | 
					
						
							|  |  |  |  |     assert seconds >= 0, "non-negative timestamp expected" | 
					
						
							|  |  |  |  |     milliseconds = round(seconds * 1000.0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     hours = milliseconds // 3_600_000 | 
					
						
							|  |  |  |  |     milliseconds -= hours * 3_600_000 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     minutes = milliseconds // 60_000 | 
					
						
							|  |  |  |  |     milliseconds -= minutes * 60_000 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     seconds = milliseconds // 1_000 | 
					
						
							|  |  |  |  |     milliseconds -= seconds * 1_000 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     hours_marker = f"{hours:02d}:" if hours > 0 else "" | 
					
						
							|  |  |  |  |     return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  | def _get_end(segments: List[dict]) -> Optional[float]: | 
					
						
							|  |  |  |  |     return next( | 
					
						
							|  |  |  |  |         (w["end"] for s in reversed(segments) for w in reversed(s["words"])), | 
					
						
							|  |  |  |  |         segments[-1]["end"] if segments else None, | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | class ModelHolder: | 
					
						
							|  |  |  |  |     model = None | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |  |     model_path = None | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |  |     def get_model(cls, model_path: str, dtype: mx.Dtype): | 
					
						
							|  |  |  |  |         if cls.model is None or model_path != cls.model_path: | 
					
						
							|  |  |  |  |             cls.model = load_model(model_path, dtype=dtype) | 
					
						
							|  |  |  |  |             cls.model_path = model_path | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |         return cls.model | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def transcribe( | 
					
						
							|  |  |  |  |     audio: Union[str, np.ndarray, mx.array], | 
					
						
							|  |  |  |  |     *, | 
					
						
							| 
									
										
										
										
											2024-01-10 16:13:06 -08:00
										 |  |  |  |     path_or_hf_repo: str = "mlx-community/whisper-tiny", | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     verbose: Optional[bool] = None, | 
					
						
							|  |  |  |  |     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), | 
					
						
							|  |  |  |  |     compression_ratio_threshold: Optional[float] = 2.4, | 
					
						
							|  |  |  |  |     logprob_threshold: Optional[float] = -1.0, | 
					
						
							|  |  |  |  |     no_speech_threshold: Optional[float] = 0.6, | 
					
						
							|  |  |  |  |     condition_on_previous_text: bool = True, | 
					
						
							|  |  |  |  |     initial_prompt: Optional[str] = None, | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     word_timestamps: bool = False, | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     prepend_punctuations: str = "\"'“¿([{-", | 
					
						
							|  |  |  |  |     append_punctuations: str = "\"'.。,,!!??::”)]}、", | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     clip_timestamps: Union[str, List[float]] = "0", | 
					
						
							|  |  |  |  |     hallucination_silence_threshold: Optional[float] = None, | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     **decode_options, | 
					
						
							|  |  |  |  | ): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     Transcribe an audio file using Whisper | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     Parameters | 
					
						
							|  |  |  |  |     ---------- | 
					
						
							|  |  |  |  |     audio: Union[str, np.ndarray, mx.array] | 
					
						
							|  |  |  |  |         The path to the audio file to open, or the audio waveform | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-08 19:50:00 +05:30
										 |  |  |  |     path_or_hf_repo: str | 
					
						
							|  |  |  |  |         The local path to the Whisper model or HF Hub repo with the MLX converted weights. | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     verbose: bool | 
					
						
							|  |  |  |  |         Whether to display the text being decoded to the console. If True, displays all the details, | 
					
						
							|  |  |  |  |         If False, displays minimal details. If None, does not display anything | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     temperature: Union[float, Tuple[float, ...]] | 
					
						
							|  |  |  |  |         Temperature for sampling. It can be a tuple of temperatures, which will be successively used | 
					
						
							|  |  |  |  |         upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     compression_ratio_threshold: float | 
					
						
							|  |  |  |  |         If the gzip compression ratio is above this value, treat as failed | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     logprob_threshold: float | 
					
						
							|  |  |  |  |         If the average log probability over sampled tokens is below this value, treat as failed | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     no_speech_threshold: float | 
					
						
							|  |  |  |  |         If the no_speech probability is higher than this value AND the average log probability | 
					
						
							|  |  |  |  |         over sampled tokens is below `logprob_threshold`, consider the segment as silent | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     condition_on_previous_text: bool | 
					
						
							|  |  |  |  |         if True, the previous output of the model is provided as a prompt for the next window; | 
					
						
							|  |  |  |  |         disabling may make the text inconsistent across windows, but the model becomes less prone to | 
					
						
							|  |  |  |  |         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     word_timestamps: bool | 
					
						
							|  |  |  |  |         Extract word-level timestamps using the cross-attention pattern and dynamic time warping, | 
					
						
							|  |  |  |  |         and include the timestamps for each word in each segment. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     prepend_punctuations: str | 
					
						
							|  |  |  |  |         If word_timestamps is True, merge these punctuation symbols with the next word | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     append_punctuations: str | 
					
						
							|  |  |  |  |         If word_timestamps is True, merge these punctuation symbols with the previous word | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     initial_prompt: Optional[str] | 
					
						
							|  |  |  |  |         Optional text to provide as a prompt for the first window. This can be used to provide, or | 
					
						
							|  |  |  |  |         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns | 
					
						
							|  |  |  |  |         to make it more likely to predict those words correctly. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     decode_options: dict | 
					
						
							|  |  |  |  |         Keyword arguments to construct `DecodingOptions` instances | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     clip_timestamps: Union[str, List[float]] | 
					
						
							|  |  |  |  |         Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. | 
					
						
							|  |  |  |  |         The last end timestamp defaults to the end of the file. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     hallucination_silence_threshold: Optional[float] | 
					
						
							|  |  |  |  |         When word_timestamps is True, skip silent periods longer than this threshold (in seconds) | 
					
						
							|  |  |  |  |         when a possible hallucination is detected | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     Returns | 
					
						
							|  |  |  |  |     ------- | 
					
						
							|  |  |  |  |     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and | 
					
						
							|  |  |  |  |     the spoken language ("language"), which is detected when `decode_options["language"]` is None. | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-12 07:37:35 -08:00
										 |  |  |  |     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 | 
					
						
							| 
									
										
										
										
											2024-01-08 19:50:00 +05:30
										 |  |  |  |     model = ModelHolder.get_model(path_or_hf_repo, dtype) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Pad 30-seconds of silence to the input audio, for slicing | 
					
						
							| 
									
										
										
										
											2023-12-12 17:26:52 +01:00
										 |  |  |  |     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     content_frames = mel.shape[-2] - N_FRAMES | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     if verbose: | 
					
						
							|  |  |  |  |         system_encoding = sys.getdefaultencoding() | 
					
						
							|  |  |  |  |         if system_encoding != "utf-8": | 
					
						
							|  |  |  |  |             make_safe = lambda x: x.encode(system_encoding, errors="replace").decode( | 
					
						
							|  |  |  |  |                 system_encoding | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             make_safe = lambda x: x | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if decode_options.get("language", None) is None: | 
					
						
							|  |  |  |  |         if not model.is_multilingual: | 
					
						
							|  |  |  |  |             decode_options["language"] = "en" | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             if verbose: | 
					
						
							|  |  |  |  |                 print( | 
					
						
							|  |  |  |  |                     "Detecting language using up to the first 30 seconds. " | 
					
						
							|  |  |  |  |                     "Use the `language` decoding option to specify the language" | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  |             mel_segment = pad_or_trim(mel, N_FRAMES, axis=-2).astype(dtype) | 
					
						
							|  |  |  |  |             _, probs = model.detect_language(mel_segment) | 
					
						
							|  |  |  |  |             decode_options["language"] = max(probs, key=probs.get) | 
					
						
							|  |  |  |  |             if verbose is not None: | 
					
						
							|  |  |  |  |                 print( | 
					
						
							|  |  |  |  |                     f"Detected language: {LANGUAGES[decode_options['language']].title()}" | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     language: str = decode_options["language"] | 
					
						
							|  |  |  |  |     task: str = decode_options.get("task", "transcribe") | 
					
						
							| 
									
										
										
										
											2023-12-12 17:26:52 +01:00
										 |  |  |  |     tokenizer = get_tokenizer( | 
					
						
							|  |  |  |  |         model.is_multilingual, | 
					
						
							|  |  |  |  |         num_languages=model.num_languages, | 
					
						
							|  |  |  |  |         language=language, | 
					
						
							|  |  |  |  |         task=task, | 
					
						
							|  |  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     if isinstance(clip_timestamps, str): | 
					
						
							|  |  |  |  |         clip_timestamps = [ | 
					
						
							|  |  |  |  |             float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) | 
					
						
							|  |  |  |  |         ] | 
					
						
							|  |  |  |  |     seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] | 
					
						
							|  |  |  |  |     if len(seek_points) == 0: | 
					
						
							|  |  |  |  |         seek_points.append(0) | 
					
						
							|  |  |  |  |     if len(seek_points) % 2 == 1: | 
					
						
							|  |  |  |  |         seek_points.append(content_frames) | 
					
						
							|  |  |  |  |     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if word_timestamps and task == "translate": | 
					
						
							|  |  |  |  |         warnings.warn("Word-level timestamps on translations may not be reliable.") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     def decode_with_fallback(segment: mx.array) -> DecodingResult: | 
					
						
							|  |  |  |  |         temperatures = ( | 
					
						
							|  |  |  |  |             [temperature] if isinstance(temperature, (int, float)) else temperature | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |         decode_result = None | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         for t in temperatures: | 
					
						
							|  |  |  |  |             kwargs = {**decode_options} | 
					
						
							|  |  |  |  |             if t > 0: | 
					
						
							|  |  |  |  |                 # disable beam_size and patience when t > 0 | 
					
						
							|  |  |  |  |                 kwargs.pop("beam_size", None) | 
					
						
							|  |  |  |  |                 kwargs.pop("patience", None) | 
					
						
							|  |  |  |  |             else: | 
					
						
							|  |  |  |  |                 # disable best_of when t == 0 | 
					
						
							|  |  |  |  |                 kwargs.pop("best_of", None) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             options = DecodingOptions(**kwargs, temperature=t) | 
					
						
							|  |  |  |  |             decode_result = model.decode(segment, options) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             needs_fallback = False | 
					
						
							|  |  |  |  |             if ( | 
					
						
							|  |  |  |  |                 compression_ratio_threshold is not None | 
					
						
							|  |  |  |  |                 and decode_result.compression_ratio > compression_ratio_threshold | 
					
						
							|  |  |  |  |             ): | 
					
						
							|  |  |  |  |                 needs_fallback = True  # too repetitive | 
					
						
							|  |  |  |  |             if ( | 
					
						
							|  |  |  |  |                 logprob_threshold is not None | 
					
						
							|  |  |  |  |                 and decode_result.avg_logprob < logprob_threshold | 
					
						
							|  |  |  |  |             ): | 
					
						
							|  |  |  |  |                 needs_fallback = True  # average log probability is too low | 
					
						
							|  |  |  |  |             if ( | 
					
						
							|  |  |  |  |                 no_speech_threshold is not None | 
					
						
							|  |  |  |  |                 and decode_result.no_speech_prob > no_speech_threshold | 
					
						
							|  |  |  |  |             ): | 
					
						
							|  |  |  |  |                 needs_fallback = False  # silence | 
					
						
							|  |  |  |  |             if not needs_fallback: | 
					
						
							|  |  |  |  |                 break | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return decode_result | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |     clip_idx = 0 | 
					
						
							|  |  |  |  |     seek = seek_clips[clip_idx][0] | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |     input_stride = N_FRAMES // model.dims.n_audio_ctx  # mel frames per output token: 2 | 
					
						
							|  |  |  |  |     time_precision = ( | 
					
						
							|  |  |  |  |         input_stride * HOP_LENGTH / SAMPLE_RATE | 
					
						
							|  |  |  |  |     )  # time per output token: 0.02 (seconds) | 
					
						
							|  |  |  |  |     all_tokens = [] | 
					
						
							|  |  |  |  |     all_segments = [] | 
					
						
							|  |  |  |  |     prompt_reset_since = 0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if initial_prompt is not None: | 
					
						
							|  |  |  |  |         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) | 
					
						
							|  |  |  |  |         all_tokens.extend(initial_prompt_tokens) | 
					
						
							|  |  |  |  |     else: | 
					
						
							|  |  |  |  |         initial_prompt_tokens = [] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def new_segment( | 
					
						
							|  |  |  |  |         *, start: float, end: float, tokens: mx.array, result: DecodingResult | 
					
						
							|  |  |  |  |     ): | 
					
						
							|  |  |  |  |         tokens = tokens.tolist() | 
					
						
							|  |  |  |  |         text_tokens = [token for token in tokens if token < tokenizer.eot] | 
					
						
							|  |  |  |  |         return { | 
					
						
							|  |  |  |  |             "seek": seek, | 
					
						
							|  |  |  |  |             "start": start, | 
					
						
							|  |  |  |  |             "end": end, | 
					
						
							|  |  |  |  |             "text": tokenizer.decode(text_tokens), | 
					
						
							|  |  |  |  |             "tokens": tokens, | 
					
						
							|  |  |  |  |             "temperature": result.temperature, | 
					
						
							|  |  |  |  |             "avg_logprob": result.avg_logprob, | 
					
						
							|  |  |  |  |             "compression_ratio": result.compression_ratio, | 
					
						
							|  |  |  |  |             "no_speech_prob": result.no_speech_prob, | 
					
						
							|  |  |  |  |         } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # show the progress bar when verbose is False (if True, transcribed text will be printed) | 
					
						
							|  |  |  |  |     with tqdm.tqdm( | 
					
						
							|  |  |  |  |         total=content_frames, unit="frames", disable=verbose is not False | 
					
						
							|  |  |  |  |     ) as pbar: | 
					
						
							|  |  |  |  |         last_speech_timestamp = 0.0 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |         for seek_clip_start, seek_clip_end in seek_clips: | 
					
						
							|  |  |  |  |             while seek < seek_clip_end: | 
					
						
							|  |  |  |  |                 time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) | 
					
						
							|  |  |  |  |                 window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) | 
					
						
							|  |  |  |  |                 segment_size = min( | 
					
						
							|  |  |  |  |                     N_FRAMES, content_frames - seek, seek_clip_end - seek | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  |                 mel_segment = mel[seek : seek + segment_size] | 
					
						
							|  |  |  |  |                 segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE | 
					
						
							|  |  |  |  |                 mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 decode_options["prompt"] = all_tokens[prompt_reset_since:] | 
					
						
							|  |  |  |  |                 result: DecodingResult = decode_with_fallback(mel_segment) | 
					
						
							| 
									
										
										
										
											2024-11-01 10:52:28 -07:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                 tokens = np.array(result.tokens) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 if no_speech_threshold is not None: | 
					
						
							|  |  |  |  |                     # no voice activity check | 
					
						
							|  |  |  |  |                     should_skip = result.no_speech_prob > no_speech_threshold | 
					
						
							|  |  |  |  |                     if ( | 
					
						
							|  |  |  |  |                         logprob_threshold is not None | 
					
						
							|  |  |  |  |                         and result.avg_logprob > logprob_threshold | 
					
						
							|  |  |  |  |                     ): | 
					
						
							|  |  |  |  |                         # don't skip if the logprob is high enough, despite the no_speech_prob | 
					
						
							|  |  |  |  |                         should_skip = False | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                     if should_skip: | 
					
						
							|  |  |  |  |                         seek += ( | 
					
						
							|  |  |  |  |                             segment_size  # fast-forward to the next segment boundary | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         continue | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 previous_seek = seek | 
					
						
							|  |  |  |  |                 current_segments = [] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 # anomalous words are very long/short/improbable | 
					
						
							|  |  |  |  |                 def word_anomaly_score(word: dict) -> float: | 
					
						
							|  |  |  |  |                     probability = word.get("probability", 0.0) | 
					
						
							|  |  |  |  |                     duration = word["end"] - word["start"] | 
					
						
							|  |  |  |  |                     score = 0.0 | 
					
						
							|  |  |  |  |                     if probability < 0.15: | 
					
						
							|  |  |  |  |                         score += 1.0 | 
					
						
							|  |  |  |  |                     if duration < 0.133: | 
					
						
							|  |  |  |  |                         score += (0.133 - duration) * 15 | 
					
						
							|  |  |  |  |                     if duration > 2.0: | 
					
						
							|  |  |  |  |                         score += duration - 2.0 | 
					
						
							|  |  |  |  |                     return score | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 def is_segment_anomaly(segment: Optional[dict]) -> bool: | 
					
						
							|  |  |  |  |                     if segment is None or not segment["words"]: | 
					
						
							|  |  |  |  |                         return False | 
					
						
							|  |  |  |  |                     words = [ | 
					
						
							|  |  |  |  |                         w for w in segment["words"] if w["word"] not in punctuation | 
					
						
							|  |  |  |  |                     ] | 
					
						
							|  |  |  |  |                     words = words[:8] | 
					
						
							|  |  |  |  |                     score = sum(word_anomaly_score(w) for w in words) | 
					
						
							|  |  |  |  |                     return score >= 3 or score + 0.01 >= len(words) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 def next_words_segment(segments: List[dict]) -> Optional[dict]: | 
					
						
							|  |  |  |  |                     return next((s for s in segments if s["words"]), None) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 timestamp_tokens = tokens >= tokenizer.timestamp_begin | 
					
						
							|  |  |  |  |                 single_timestamp_ending = timestamp_tokens[-2:].tolist() == [ | 
					
						
							|  |  |  |  |                     False, | 
					
						
							|  |  |  |  |                     True, | 
					
						
							|  |  |  |  |                 ] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 consecutive = np.where( | 
					
						
							|  |  |  |  |                     np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:]) | 
					
						
							|  |  |  |  |                 )[0] | 
					
						
							|  |  |  |  |                 consecutive += 1 | 
					
						
							|  |  |  |  |                 if len(consecutive) > 0: | 
					
						
							|  |  |  |  |                     # if the output contains two consecutive timestamp tokens | 
					
						
							|  |  |  |  |                     slices = consecutive.tolist() | 
					
						
							|  |  |  |  |                     if single_timestamp_ending: | 
					
						
							|  |  |  |  |                         slices.append(len(tokens)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                     last_slice = 0 | 
					
						
							|  |  |  |  |                     for current_slice in slices: | 
					
						
							|  |  |  |  |                         sliced_tokens = tokens[last_slice:current_slice] | 
					
						
							|  |  |  |  |                         start_timestamp_pos = ( | 
					
						
							|  |  |  |  |                             sliced_tokens[0].item() - tokenizer.timestamp_begin | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         end_timestamp_pos = ( | 
					
						
							|  |  |  |  |                             sliced_tokens[-1].item() - tokenizer.timestamp_begin | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         current_segments.append( | 
					
						
							|  |  |  |  |                             new_segment( | 
					
						
							|  |  |  |  |                                 start=time_offset | 
					
						
							|  |  |  |  |                                 + start_timestamp_pos * time_precision, | 
					
						
							|  |  |  |  |                                 end=time_offset + end_timestamp_pos * time_precision, | 
					
						
							|  |  |  |  |                                 tokens=sliced_tokens, | 
					
						
							|  |  |  |  |                                 result=result, | 
					
						
							|  |  |  |  |                             ) | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         last_slice = current_slice | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                     if single_timestamp_ending: | 
					
						
							|  |  |  |  |                         # single timestamp at the end means no speech after the last timestamp. | 
					
						
							|  |  |  |  |                         seek += segment_size | 
					
						
							|  |  |  |  |                     else: | 
					
						
							|  |  |  |  |                         # otherwise, ignore the unfinished segment and seek to the last timestamp | 
					
						
							|  |  |  |  |                         last_timestamp_pos = ( | 
					
						
							|  |  |  |  |                             tokens[last_slice - 1].item() - tokenizer.timestamp_begin | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         seek += last_timestamp_pos * input_stride | 
					
						
							|  |  |  |  |                 else: | 
					
						
							|  |  |  |  |                     duration = segment_duration | 
					
						
							|  |  |  |  |                     timestamps = tokens[timestamp_tokens.nonzero()[0]] | 
					
						
							|  |  |  |  |                     if ( | 
					
						
							|  |  |  |  |                         len(timestamps) > 0 | 
					
						
							|  |  |  |  |                         and timestamps[-1].item() != tokenizer.timestamp_begin | 
					
						
							|  |  |  |  |                     ): | 
					
						
							|  |  |  |  |                         # no consecutive timestamps but it has a timestamp; use the last one. | 
					
						
							|  |  |  |  |                         last_timestamp_pos = ( | 
					
						
							|  |  |  |  |                             timestamps[-1].item() - tokenizer.timestamp_begin | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         duration = last_timestamp_pos * time_precision | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |                     current_segments.append( | 
					
						
							|  |  |  |  |                         new_segment( | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                             start=time_offset, | 
					
						
							|  |  |  |  |                             end=time_offset + duration, | 
					
						
							|  |  |  |  |                             tokens=tokens, | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  |                             result=result, | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                     ) | 
					
						
							|  |  |  |  |                     seek += segment_size | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                 if word_timestamps: | 
					
						
							|  |  |  |  |                     add_word_timestamps( | 
					
						
							|  |  |  |  |                         segments=current_segments, | 
					
						
							|  |  |  |  |                         model=model, | 
					
						
							|  |  |  |  |                         tokenizer=tokenizer, | 
					
						
							|  |  |  |  |                         mel=mel_segment, | 
					
						
							|  |  |  |  |                         num_frames=segment_size, | 
					
						
							|  |  |  |  |                         prepend_punctuations=prepend_punctuations, | 
					
						
							|  |  |  |  |                         append_punctuations=append_punctuations, | 
					
						
							|  |  |  |  |                         last_speech_timestamp=last_speech_timestamp, | 
					
						
							|  |  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |                     if not single_timestamp_ending: | 
					
						
							|  |  |  |  |                         last_word_end = _get_end(current_segments) | 
					
						
							|  |  |  |  |                         if last_word_end is not None and last_word_end > time_offset: | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                             seek = round(last_word_end * FRAMES_PER_SECOND) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                     # skip silence before possible hallucinations | 
					
						
							|  |  |  |  |                     if hallucination_silence_threshold is not None: | 
					
						
							|  |  |  |  |                         threshold = hallucination_silence_threshold | 
					
						
							|  |  |  |  |                         if not single_timestamp_ending: | 
					
						
							|  |  |  |  |                             last_word_end = _get_end(current_segments) | 
					
						
							|  |  |  |  |                             if ( | 
					
						
							|  |  |  |  |                                 last_word_end is not None | 
					
						
							|  |  |  |  |                                 and last_word_end > time_offset | 
					
						
							|  |  |  |  |                             ): | 
					
						
							|  |  |  |  |                                 remaining_duration = window_end_time - last_word_end | 
					
						
							|  |  |  |  |                                 if remaining_duration > threshold: | 
					
						
							|  |  |  |  |                                     seek = round(last_word_end * FRAMES_PER_SECOND) | 
					
						
							|  |  |  |  |                                 else: | 
					
						
							|  |  |  |  |                                     seek = previous_seek + segment_size | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                         # if first segment might be a hallucination, skip leading silence | 
					
						
							|  |  |  |  |                         first_segment = next_words_segment(current_segments) | 
					
						
							|  |  |  |  |                         if first_segment is not None and is_segment_anomaly( | 
					
						
							|  |  |  |  |                             first_segment | 
					
						
							|  |  |  |  |                         ): | 
					
						
							|  |  |  |  |                             gap = first_segment["start"] - time_offset | 
					
						
							|  |  |  |  |                             if gap > threshold: | 
					
						
							|  |  |  |  |                                 seek = previous_seek + round(gap * FRAMES_PER_SECOND) | 
					
						
							|  |  |  |  |                                 continue | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                         # skip silence before any possible hallucination that is surrounded | 
					
						
							|  |  |  |  |                         # by silence or more hallucinations | 
					
						
							|  |  |  |  |                         hal_last_end = last_speech_timestamp | 
					
						
							|  |  |  |  |                         for si in range(len(current_segments)): | 
					
						
							|  |  |  |  |                             segment = current_segments[si] | 
					
						
							|  |  |  |  |                             if not segment["words"]: | 
					
						
							|  |  |  |  |                                 continue | 
					
						
							|  |  |  |  |                             if is_segment_anomaly(segment): | 
					
						
							|  |  |  |  |                                 next_segment = next_words_segment( | 
					
						
							|  |  |  |  |                                     current_segments[si + 1 :] | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  |                                 ) | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                                 if next_segment is not None: | 
					
						
							|  |  |  |  |                                     hal_next_start = next_segment["words"][0]["start"] | 
					
						
							|  |  |  |  |                                 else: | 
					
						
							|  |  |  |  |                                     hal_next_start = time_offset + segment_duration | 
					
						
							|  |  |  |  |                                 silence_before = ( | 
					
						
							|  |  |  |  |                                     segment["start"] - hal_last_end > threshold | 
					
						
							|  |  |  |  |                                     or segment["start"] < threshold | 
					
						
							|  |  |  |  |                                     or segment["start"] - time_offset < 2.0 | 
					
						
							|  |  |  |  |                                 ) | 
					
						
							|  |  |  |  |                                 silence_after = ( | 
					
						
							|  |  |  |  |                                     hal_next_start - segment["end"] > threshold | 
					
						
							|  |  |  |  |                                     or is_segment_anomaly(next_segment) | 
					
						
							|  |  |  |  |                                     or window_end_time - segment["end"] < 2.0 | 
					
						
							|  |  |  |  |                                 ) | 
					
						
							|  |  |  |  |                                 if silence_before and silence_after: | 
					
						
							|  |  |  |  |                                     seek = round( | 
					
						
							|  |  |  |  |                                         max(time_offset + 1, segment["start"]) | 
					
						
							|  |  |  |  |                                         * FRAMES_PER_SECOND | 
					
						
							|  |  |  |  |                                     ) | 
					
						
							|  |  |  |  |                                     if content_duration - segment["end"] < threshold: | 
					
						
							|  |  |  |  |                                         seek = content_frames | 
					
						
							|  |  |  |  |                                     current_segments[si:] = [] | 
					
						
							|  |  |  |  |                                     break | 
					
						
							|  |  |  |  |                             hal_last_end = segment["end"] | 
					
						
							| 
									
										
										
										
											2024-01-07 19:01:29 +01:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                     last_word_end = _get_end(current_segments) | 
					
						
							|  |  |  |  |                     if last_word_end is not None: | 
					
						
							|  |  |  |  |                         last_speech_timestamp = last_word_end | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 if verbose: | 
					
						
							|  |  |  |  |                     for segment in current_segments: | 
					
						
							|  |  |  |  |                         start, end, text = ( | 
					
						
							|  |  |  |  |                             segment["start"], | 
					
						
							|  |  |  |  |                             segment["end"], | 
					
						
							|  |  |  |  |                             segment["text"], | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                         line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}" | 
					
						
							|  |  |  |  |                         print(make_safe(line)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 # if a segment is instantaneous or does not contain text, clear it | 
					
						
							|  |  |  |  |                 for i, segment in enumerate(current_segments): | 
					
						
							|  |  |  |  |                     if ( | 
					
						
							|  |  |  |  |                         segment["start"] == segment["end"] | 
					
						
							|  |  |  |  |                         or segment["text"].strip() == "" | 
					
						
							|  |  |  |  |                     ): | 
					
						
							|  |  |  |  |                         segment["text"] = "" | 
					
						
							|  |  |  |  |                         segment["tokens"] = [] | 
					
						
							|  |  |  |  |                         segment["words"] = [] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |                 all_segments.extend( | 
					
						
							|  |  |  |  |                     [ | 
					
						
							|  |  |  |  |                         {"id": i, **segment} | 
					
						
							|  |  |  |  |                         for i, segment in enumerate( | 
					
						
							|  |  |  |  |                             current_segments, start=len(all_segments) | 
					
						
							|  |  |  |  |                         ) | 
					
						
							|  |  |  |  |                     ] | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  |                 all_tokens.extend( | 
					
						
							|  |  |  |  |                     [ | 
					
						
							|  |  |  |  |                         token | 
					
						
							|  |  |  |  |                         for segment in current_segments | 
					
						
							|  |  |  |  |                         for token in segment["tokens"] | 
					
						
							|  |  |  |  |                     ] | 
					
						
							|  |  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                 if not condition_on_previous_text or result.temperature > 0.5: | 
					
						
							|  |  |  |  |                     # do not feed the prompt tokens if a high temperature was used | 
					
						
							|  |  |  |  |                     prompt_reset_since = len(all_tokens) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-30 13:13:58 -07:00
										 |  |  |  |                 # update progress bar | 
					
						
							|  |  |  |  |                 pbar.update(min(content_frames, seek) - previous_seek) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     return dict( | 
					
						
							|  |  |  |  |         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), | 
					
						
							|  |  |  |  |         segments=all_segments, | 
					
						
							|  |  |  |  |         language=language, | 
					
						
							|  |  |  |  |     ) |