diff --git a/whisper/whisper/assets/mel_filters.npz b/whisper/whisper/assets/mel_filters.npz
index 1a783924..28ea2690 100644
Binary files a/whisper/whisper/assets/mel_filters.npz and b/whisper/whisper/assets/mel_filters.npz differ
diff --git a/whisper/whisper/audio.py b/whisper/whisper/audio.py
index 549158a0..5e63fb7f 100644
--- a/whisper/whisper/audio.py
+++ b/whisper/whisper/audio.py
@@ -11,7 +11,6 @@ import numpy as np
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -81,7 +80,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 
 
 @lru_cache(maxsize=None)
-def mel_filters(n_mels: int = N_MELS) -> mx.array:
+def mel_filters(n_mels: int) -> mx.array:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -89,9 +88,10 @@ def mel_filters(n_mels: int = N_MELS) -> mx.array:
         np.savez_compressed(
             "mel_filters.npz",
             mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
         )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
 
     filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     return mx.load(filename)[f"mel_{n_mels}"]
@@ -130,7 +130,7 @@ def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="re
 
 def log_mel_spectrogram(
     audio: Union[str, np.ndarray],
-    n_mels: int = N_MELS,
+    n_mels: int = 80,
     padding: int = 0,
 ):
     """
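Since `mel_filters()` now serves two filterbank sizes from the same archive, the regenerated `mel_filters.npz` can be sanity-checked by rebuilding it exactly as the docstring above describes and comparing shapes. A minimal sketch, assuming `librosa` is available (it is only needed to regenerate the asset, not at runtime):

    import librosa
    import numpy as np

    # Rebuild the archive as documented in the mel_filters() docstring.
    np.savez_compressed(
        "mel_filters.npz",
        mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
    )

    # Each bank maps the N_FFT // 2 + 1 = 201 STFT bins onto n_mels bands.
    filters = np.load("mel_filters.npz")
    assert filters["mel_80"].shape == (80, 201)
    assert filters["mel_128"].shape == (128, 201)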
""" if tokenizer is None: - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer( + model.is_multilingual, num_languages=model.num_languages + ) if ( tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence @@ -401,7 +403,10 @@ class DecodingTask: language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=options.task + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, ) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 6a4e301b..58cef9ac 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -26,7 +26,8 @@ _MODELS = { "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", - "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", + "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", } # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are @@ -42,7 +43,8 @@ _ALIGNMENT_HEADS = { "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", - "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", + "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", + "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00" } diff --git a/whisper/whisper/tokenizer.py b/whisper/whisper/tokenizer.py index 5e345508..b589f764 100644 --- a/whisper/whisper/tokenizer.py +++ b/whisper/whisper/tokenizer.py @@ -109,6 +109,7 @@ LANGUAGES = { "ba": "bashkir", "jw": "javanese", "su": "sundanese", + "yue": "cantonese", } # language code lookup by name, with a few language aliases @@ -125,6 +126,7 @@ TO_LANGUAGE_CODE = { "moldovan": "ro", "sinhalese": "si", "castilian": "es", + "mandarin": "zh", } @@ -133,6 +135,7 @@ class Tokenizer: """A thin wrapper around `tiktoken` providing quick access to special tokens""" encoding: tiktoken.Encoding + num_languages: int language: Optional[str] = None task: Optional[str] = None sot_sequence: Tuple[int] = () @@ -147,7 +150,7 @@ class Tokenizer: translate: int = self.special_tokens["<|translate|>"] transcribe: int = self.special_tokens["<|transcribe|>"] - langs = tuple(LANGUAGES.keys()) + langs = tuple(LANGUAGES.keys())[: self.num_languages] sot_sequence = [sot] if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language)) @@ -213,10 +216,13 @@ class Tokenizer: if self.language is None: raise ValueError("This tokenizer does not have language token configured") - if token := self.special_tokens.get(f"<|{self.language}|>", 
diff --git a/whisper/whisper/tokenizer.py b/whisper/whisper/tokenizer.py
index 5e345508..b589f764 100644
--- a/whisper/whisper/tokenizer.py
+++ b/whisper/whisper/tokenizer.py
@@ -109,6 +109,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }
 
 # language code lookup by name, with a few language aliases
@@ -125,6 +126,7 @@ TO_LANGUAGE_CODE = {
     "moldovan": "ro",
     "sinhalese": "si",
     "castilian": "es",
+    "mandarin": "zh",
 }
 
 
@@ -133,6 +135,7 @@ class Tokenizer:
     """A thin wrapper around `tiktoken` providing quick access to special tokens"""
 
     encoding: tiktoken.Encoding
+    num_languages: int
     language: Optional[str] = None
     task: Optional[str] = None
     sot_sequence: Tuple[int] = ()
@@ -147,7 +150,7 @@
         translate: int = self.special_tokens["<|translate|>"]
         transcribe: int = self.special_tokens["<|transcribe|>"]
 
-        langs = tuple(LANGUAGES.keys())
+        langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
             sot_sequence.append(sot + 1 + langs.index(self.language))
@@ -213,10 +216,13 @@
         if self.language is None:
             raise ValueError("This tokenizer does not have language token configured")
 
-        if token := self.special_tokens.get(f"<|{self.language}|>", None):
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
             return token
 
-        raise KeyError(f"Language {self.language} not found in tokenizer.")
+        raise KeyError(f"Language {language} not found in tokenizer.")
 
     @cached_property
     def all_language_tokens(self) -> Tuple[int]:
@@ -224,7 +230,7 @@
         for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
-        return tuple(result)
+        return tuple(result)[: self.num_languages]
 
     @cached_property
     def all_language_codes(self) -> Tuple[str]:
@@ -271,7 +277,7 @@
         return tuple(sorted(result))
 
     def split_to_word_tokens(self, tokens: List[int]):
-        if self.language in {"zh", "ja", "th", "lo", "my"}:
+        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
             # These languages don't typically use spaces, so it is difficult to split words
             # without morpheme analysis. Here, we instead split words at any
             # position where the tokens are decoded as valid unicode points
@@ -324,7 +330,7 @@
 
 
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2"):
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
     with open(vocab_path) as fid:
         ranks = {
@@ -337,7 +343,7 @@
     specials = [
         "<|endoftext|>",
         "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
         "<|translate|>",
         "<|transcribe|>",
         "<|startoflm|>",
@@ -364,6 +370,7 @@
 def get_tokenizer(
     multilingual: bool,
     *,
+    num_languages: int = 99,
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
 ) -> Tokenizer:
@@ -384,6 +391,8 @@
         language = None
         task = None
 
-    encoding = get_encoding(name=encoding_name)
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
 
-    return Tokenizer(encoding=encoding, language=language, task=task)
+    return Tokenizer(
+        encoding=encoding, num_languages=num_languages, language=language, task=task
+    )
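The observable effect of `num_languages` on the vocabulary: `<|yue|>` is appended after the existing 99 language tokens, so every later special token (`<|translate|>`, `<|transcribe|>`, the timestamps) shifts up by one id, which is why `num_languages` has to be threaded through everywhere above. A quick check, assuming the package is importable as `whisper` and the tiktoken vocab assets are on disk:

    from whisper.tokenizer import get_tokenizer

    tok_v2 = get_tokenizer(True, num_languages=99)   # large-v2-style layout
    tok_v3 = get_tokenizer(True, num_languages=100)  # large-v3-style layout

    assert "<|yue|>" in tok_v3.special_tokens
    assert "<|yue|>" not in tok_v2.special_tokens
    # Everything after the language block shifts by exactly one token id.
    assert (
        tok_v3.special_tokens["<|translate|>"]
        == tok_v2.special_tokens["<|translate|>"] + 1
    )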
diff --git a/whisper/whisper/torch_whisper.py b/whisper/whisper/torch_whisper.py
index 0ffcf302..3b5491e4 100644
--- a/whisper/whisper/torch_whisper.py
+++ b/whisper/whisper/torch_whisper.py
@@ -234,7 +234,8 @@ class Whisper(nn.Module):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
-        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
         all_heads = torch.zeros(
             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
         )
@@ -267,7 +268,11 @@
 
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
 
     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
         """
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index f05b828c..3172bdb3 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -118,7 +118,7 @@
     model = ModelHolder.get_model(model, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-2] - N_FRAMES
 
     if verbose:
@@ -149,7 +149,12 @@
     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
 
-    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+    tokenizer = get_tokenizer(
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language=language,
+        task=task,
+    )
 
     def decode_with_fallback(segment: mx.array) -> DecodingResult:
         temperatures = (
diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py
index 1c7b856f..62e43de3 100644
--- a/whisper/whisper/whisper.py
+++ b/whisper/whisper/whisper.py
@@ -216,7 +216,11 @@
 
     @property
    def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
 
     detect_language = detect_language_function
     decode = decode_function
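The constants in the new `num_languages` property decode as follows: a multilingual vocabulary has 51865 tokens with 99 languages (large-v2 and earlier) and 51866 once `<|yue|>` is added (large-v3), while English-only models have 51864. Subtracting 51765 and then the multilingual flag recovers the language count in each case; a worked check of the arithmetic:

    def num_languages(n_vocab: int) -> int:
        # mirrors the property added to both Whisper classes above
        is_multilingual = n_vocab >= 51865
        return n_vocab - 51765 - int(is_multilingual)

    assert num_languages(51865) == 99   # multilingual, large-v2 and earlier
    assert num_languages(51866) == 100  # multilingual with <|yue|>, large-v3
    assert num_languages(51864) == 99   # English-only vocabulary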