From 94705ed38bb5b07d7440f50e421d8e01404d15f1 Mon Sep 17 00:00:00 2001
From: bofenghuang
Date: Tue, 12 Dec 2023 17:26:52 +0100
Subject: [PATCH] Add large v3

---
 whisper/whisper/assets/mel_filters.npz | Bin 2048 -> 4271 bytes
 whisper/whisper/audio.py               |  8 ++++----
 whisper/whisper/decoding.py            |  9 +++++++--
 whisper/whisper/load_models.py         |  6 ++++--
 whisper/whisper/tokenizer.py           | 27 ++++++++++++++++---------
 whisper/whisper/torch_whisper.py       |  9 +++++++--
 whisper/whisper/transcribe.py          |  9 +++++++--
 whisper/whisper/whisper.py             |  6 +++++-
 8 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/whisper/whisper/assets/mel_filters.npz b/whisper/whisper/assets/mel_filters.npz
index 1a7839244dfb6b1cc02e4f3cfe12e4817a073bc7..28ea26909dbdfd608aef67afc4d74d7961ae4bb6 100644
GIT binary patch
[base85 binary data omitted: literal 4271, delta 2007. The regenerated
mel_filters.npz carries the new 128-bin filterbank alongside the existing
80-bin one; the mel_filters() docstring below shows how it is produced.]

diff --git a/whisper/whisper/audio.py b/whisper/whisper/audio.py
--- a/whisper/whisper/audio.py
+++ b/whisper/whisper/audio.py
@@ ... @@
-def mel_filters(n_mels: int = N_MELS) -> mx.array:
+def mel_filters(n_mels: int) -> mx.array:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using: @@ -89,9 +88,10 @@ def mel_filters(n_mels: int = N_MELS) -> mx.array: np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) """ - assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") return mx.load(filename)[f"mel_{n_mels}"] @@ -130,7 +130,7 @@ def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="re def log_mel_spectrogram( audio: Union[str, np.ndarray], - n_mels: int = N_MELS, + n_mels: int = 80, padding: int = 0, ): """ diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index c4b6326d..dd0fa5e5 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -33,7 +33,9 @@ def detect_language( list of dictionaries containing the probability distribution over all languages. """ if tokenizer is None: - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer( + model.is_multilingual, num_languages=model.num_languages + ) if ( tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence @@ -401,7 +403,10 @@ class DecodingTask: language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=options.task + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, ) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 2d0ae578..2d9b1fea 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -25,7 +25,8 @@ _MODELS = { "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", - "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", + "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", } # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are @@ -41,7 +42,8 @@ _ALIGNMENT_HEADS = { "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", - "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", + "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", + "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00" } diff --git a/whisper/whisper/tokenizer.py b/whisper/whisper/tokenizer.py index 5e345508..b589f764 100644 --- a/whisper/whisper/tokenizer.py +++ b/whisper/whisper/tokenizer.py @@ -109,6 +109,7 @@ LANGUAGES = { "ba": "bashkir", 
"jw": "javanese", "su": "sundanese", + "yue": "cantonese", } # language code lookup by name, with a few language aliases @@ -125,6 +126,7 @@ TO_LANGUAGE_CODE = { "moldovan": "ro", "sinhalese": "si", "castilian": "es", + "mandarin": "zh", } @@ -133,6 +135,7 @@ class Tokenizer: """A thin wrapper around `tiktoken` providing quick access to special tokens""" encoding: tiktoken.Encoding + num_languages: int language: Optional[str] = None task: Optional[str] = None sot_sequence: Tuple[int] = () @@ -147,7 +150,7 @@ class Tokenizer: translate: int = self.special_tokens["<|translate|>"] transcribe: int = self.special_tokens["<|transcribe|>"] - langs = tuple(LANGUAGES.keys()) + langs = tuple(LANGUAGES.keys())[: self.num_languages] sot_sequence = [sot] if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language)) @@ -213,10 +216,13 @@ class Tokenizer: if self.language is None: raise ValueError("This tokenizer does not have language token configured") - if token := self.special_tokens.get(f"<|{self.language}|>", None): + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): return token - raise KeyError(f"Language {self.language} not found in tokenizer.") + raise KeyError(f"Language {language} not found in tokenizer.") @cached_property def all_language_tokens(self) -> Tuple[int]: @@ -224,7 +230,7 @@ class Tokenizer: for token, token_id in self.special_tokens.items(): if token.strip("<|>") in LANGUAGES: result.append(token_id) - return tuple(result) + return tuple(result)[: self.num_languages] @cached_property def all_language_codes(self) -> Tuple[str]: @@ -271,7 +277,7 @@ class Tokenizer: return tuple(sorted(result)) def split_to_word_tokens(self, tokens: List[int]): - if self.language in {"zh", "ja", "th", "lo", "my"}: + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: # These languages don't typically use spaces, so it is difficult to split words # without morpheme analysis. 
Here, we instead split words at any # position where the tokens are decoded as valid unicode points @@ -324,7 +330,7 @@ class Tokenizer: @lru_cache(maxsize=None) -def get_encoding(name: str = "gpt2"): +def get_encoding(name: str = "gpt2", num_languages: int = 99): vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") with open(vocab_path) as fid: ranks = { @@ -337,7 +343,7 @@ def get_encoding(name: str = "gpt2"): specials = [ "<|endoftext|>", "<|startoftranscript|>", - *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", @@ -364,6 +370,7 @@ def get_encoding(name: str = "gpt2"): def get_tokenizer( multilingual: bool, *, + num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None, # Literal["transcribe", "translate", None] ) -> Tokenizer: @@ -384,6 +391,8 @@ def get_tokenizer( language = None task = None - encoding = get_encoding(name=encoding_name) + encoding = get_encoding(name=encoding_name, num_languages=num_languages) - return Tokenizer(encoding=encoding, language=language, task=task) + return Tokenizer( + encoding=encoding, num_languages=num_languages, language=language, task=task + ) diff --git a/whisper/whisper/torch_whisper.py b/whisper/whisper/torch_whisper.py index 0ffcf302..3b5491e4 100644 --- a/whisper/whisper/torch_whisper.py +++ b/whisper/whisper/torch_whisper.py @@ -234,7 +234,8 @@ class Whisper(nn.Module): self.dims.n_text_head, self.dims.n_text_layer, ) - # use the last half layers for alignment by default; see `set_alignment_heads()` below + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. 
all_heads = torch.zeros( self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool ) @@ -267,7 +268,11 @@ class Whisper(nn.Module): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) def install_kv_cache_hooks(self, cache: Optional[dict] = None): """ diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index bfdc32b5..40b9ec40 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -119,7 +119,7 @@ def transcribe( dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32 # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, padding=N_SAMPLES) + mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) content_frames = mel.shape[-2] - N_FRAMES if verbose: @@ -150,7 +150,12 @@ def transcribe( language: str = decode_options["language"] task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=task, + ) def decode_with_fallback(segment: mx.array) -> DecodingResult: temperatures = ( diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index ec60c6ec..983819ef 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -210,7 +210,11 @@ class Whisper(nn.Module): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) detect_language = detect_language_function decode = decode_function
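--
Notes:

The new 128-bin mel filterbank shipped in assets/mel_filters.npz can be
regenerated exactly as the updated mel_filters() docstring describes. A
minimal sketch of that one-off step (librosa is needed only here, not at
runtime):

    import librosa
    import numpy as np

    # Regenerate assets/mel_filters.npz with both filterbanks, matching
    # the snippet in the mel_filters() docstring.
    np.savez_compressed(
        "mel_filters.npz",
        mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
    )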
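The num_languages property added to both Whisper classes recovers the
language count from the vocabulary size: each language contributes one
<|lang|> special token, so n_vocab grows by one per language, with one
extra special-token offset on multilingual vocabs. A standalone sanity
check of the arithmetic, assuming the published vocab sizes (51865 for
the earlier multilingual models, 51866 for large-v3):

    # Same formula as the patched num_languages property.
    def num_languages(n_vocab: int) -> int:
        is_multilingual = n_vocab >= 51865  # mirrors the patched check
        return n_vocab - 51765 - int(is_multilingual)

    assert num_languages(51865) == 99   # large-v2 and earlier multilingual
    assert num_languages(51866) == 100  # large-v3: 99 languages + "yue"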
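With that plumbing in place, the new Cantonese token is addressable end
to end. A usage sketch, assuming the repo's whisper.tokenizer module path
and the tiktoken vocab assets on disk; the printed values are
illustrative:

    from whisper.tokenizer import get_tokenizer

    # Large-v3-style tokenizer: 100 languages, transcribing Cantonese.
    tokenizer = get_tokenizer(
        True,  # multilingual
        num_languages=100,
        language="yue",
        task="transcribe",
    )
    print(tokenizer.language_token)  # id of the <|yue|> special token
    print(tokenizer.sot_sequence)    # <|startoftranscript|>, <|yue|>, <|transcribe|>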