From 6e723a015a5b9d8f39d64a8e2788d87709c45123 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 12 Dec 2023 07:37:35 -0800 Subject: [PATCH 1/7] whisper default in fp16 --- whisper/benchmark.py | 9 +++++---- whisper/test.py | 23 ++++++++++++++++------- whisper/whisper/decoding.py | 6 +++--- whisper/whisper/load_models.py | 9 ++++++--- whisper/whisper/transcribe.py | 9 ++++----- whisper/whisper/whisper.py | 26 ++++++++++++++++---------- 6 files changed, 50 insertions(+), 32 deletions(-) diff --git a/whisper/benchmark.py b/whisper/benchmark.py index 9df6b500..228a3b36 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -57,12 +57,13 @@ if __name__ == "__main__": if sys.argv[1] == "--all": models = ["tiny", "small", "medium", "large"] + feat_time = timer(feats) + print(f"\nFeature time {feat_time:.3f}") + mels = feats()[None].astype(mx.float16) + for model_name in models: - feat_time = timer(feats) print(f"\nModel: {model_name.upper()}") - print(f"\nFeature time {feat_time:.3f}") - mels = feats()[None] tokens = mx.array( [ 50364, @@ -96,7 +97,7 @@ if __name__ == "__main__": ], mx.int32, )[None] - model = load_models.load_model(f"{model_name}") + model = load_models.load_model(f"{model_name}", dtype=mx.float16) model_forward_time = timer(model_forward, model, mels, tokens) print(f"Model forward time {model_forward_time:.3f}") decode_time = timer(decode, model, mels) diff --git a/whisper/test.py b/whisper/test.py index 44f99edf..79f233ba 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -36,7 +36,7 @@ def forward_mlx(model, mels, tokens): class TestWhisper(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = load_models.load_model("tiny") + cls.model = load_models.load_model("tiny", dtype=mx.float32) data = audio.load_audio(TEST_AUDIO) data = audio.pad_or_trim(data) cls.mels = audio.log_mel_spectrogram(data) @@ -52,13 +52,22 @@ class TestWhisper(unittest.TestCase): torch_logits = forward_torch(torch_model, mels, tokens) - mlx_model = load_models.torch_to_mlx(torch_model) + mlx_model = load_models.torch_to_mlx(torch_model, mx.float32) mlx_logits = forward_mlx(mlx_model, mels, tokens) self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2)) + def test_fp16(self): + mlx_model = load_models.load_model("tiny", dtype=mx.float16) + dims = mlx_model.dims + mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16) + tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32) + logits = mlx_model(mels, tokens) + self.assertEqual(logits.dtype, mx.float16) + + def test_decode_lang(self): - options = decoding.DecodingOptions(task="lang_id") + options = decoding.DecodingOptions(task="lang_id", fp16=False) result = decoding.decode(self.model, self.mels, options) self.assertEqual(result.language, "en") self.assertEqual(len(result.language_probs), 99) @@ -67,7 +76,7 @@ class TestWhisper(unittest.TestCase): ) def test_decode_greedy(self): - result = decoding.decode(self.model, self.mels) + result = decoding.decode(self.model, self.mels, fp16=False) self.assertEqual(result.language, "en") self.assertEqual( result.tokens, @@ -114,7 +123,7 @@ class TestWhisper(unittest.TestCase): self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752) # Small temp should give the same results - result = decoding.decode(self.model, self.mels, temperature=1e-8) + result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False) self.assertEqual( result.text, @@ -128,7 +137,7 @@ class TestWhisper(unittest.TestCase): 
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752) def test_transcribe(self): - result = whisper.transcribe(TEST_AUDIO) + result = whisper.transcribe(TEST_AUDIO, fp16=False) self.assertEqual( result["text"], ( @@ -147,7 +156,7 @@ class TestWhisper(unittest.TestCase): print("bash path_to_whisper_repo/whisper/assets/download_alice.sh") return - result = whisper.transcribe(audio_file) + result = whisper.transcribe(audio_file, fp16=False) self.assertEqual(len(result["text"]), 10920) self.assertEqual(result["language"], "en") self.assertEqual(len(result["segments"]), 77) diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index c4b6326d..d63d5e98 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -110,7 +110,7 @@ class DecodingOptions: max_initial_timestamp: Optional[float] = 1.0 # implementation details - fp16: bool = False # use fp16 for most of the calculation + fp16: bool = True # use fp16 for most of the calculation @dataclass(frozen=True) @@ -141,7 +141,7 @@ class Inference: logits, self.kv_cache = self.model.decoder( tokens, audio_features, kv_cache=self.kv_cache ) - return logits + return logits.astype(mx.float32) def rearrange_kv_cache(self, source_indices): """Update the key-value cache according to the updated beams""" @@ -542,7 +542,7 @@ class DecodingTask: audio_features = self.model.encoder(mel) if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32): - return TypeError( + raise TypeError( f"audio_features has an incorrect dtype: {audio_features.dtype}" ) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 2d0ae578..6a4e301b 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -7,6 +7,7 @@ import warnings from typing import List import mlx.core as mx +from mlx.utils import tree_map import torch from tqdm import tqdm @@ -163,7 +164,7 @@ def convert(model, rules=None): def torch_to_mlx( - torch_model: torch_whisper.Whisper, + torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16, ) -> whisper.Whisper: def convert_rblock(model, rules): children = dict(model.named_children()) @@ -182,7 +183,8 @@ def torch_to_mlx( params = convert(torch_model, rules) - mlx_model = whisper.Whisper(torch_model.dims) + mlx_model = whisper.Whisper(torch_model.dims, dtype) + params = tree_map(lambda p: p.astype(dtype), params) mlx_model.update(params) return mlx_model @@ -190,5 +192,6 @@ def torch_to_mlx( def load_model( name: str, download_root: str = None, + dtype : mx.Dtype = mx.float32, ) -> whisper.Whisper: - return torch_to_mlx(load_torch_model(name, download_root)) + return torch_to_mlx(load_torch_model(name, download_root), dtype) diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index bfdc32b5..f05b828c 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -43,9 +43,9 @@ class ModelHolder: model_name = None @classmethod - def get_model(cls, model: str): + def get_model(cls, model: str, dtype : mx.Dtype): if cls.model is None or model != cls.model_name: - cls.model = load_model(model) + cls.model = load_model(model, dtype=dtype) cls.model_name = model return cls.model @@ -114,9 +114,8 @@ def transcribe( the spoken language ("language"), which is detected when `decode_options["language"]` is None. 
""" - model = ModelHolder.get_model(model) - - dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32 + dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 + model = ModelHolder.get_model(model, dtype) # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram(audio, padding=N_SAMPLES) diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index ec60c6ec..1c7b856f 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -37,6 +37,10 @@ def sinusoids(length, channels, max_timescale=10000): scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :] return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1) +class LayerNorm(nn.LayerNorm): + def __call__(self, x: mx.array) -> mx.array: + return super().__call__(x.astype(mx.float32)).astype(x.dtype) + class MultiHeadAttention(nn.Module): def __init__(self, n_state: int, n_head: int): @@ -94,17 +98,17 @@ class ResidualAttentionBlock(nn.Module): super().__init__() self.attn = MultiHeadAttention(n_state, n_head) - self.attn_ln = nn.LayerNorm(n_state) + self.attn_ln = LayerNorm(n_state) self.cross_attn = ( MultiHeadAttention(n_state, n_head) if cross_attention else None ) - self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None n_mlp = n_state * 4 self.mlp1 = nn.Linear(n_state, n_mlp) self.mlp2 = nn.Linear(n_mlp, n_state) - self.mlp_ln = nn.LayerNorm(n_state) + self.mlp_ln = LayerNorm(n_state) def __call__(self, x, xa=None, mask=None, kv_cache=None): kv, cross_kv = kv_cache if kv_cache else (None, None) @@ -119,15 +123,15 @@ class ResidualAttentionBlock(nn.Module): class AudioEncoder(nn.Module): def __init__( - self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, ): super().__init__() self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) - self._positional_embedding = sinusoids(n_ctx, n_state) + self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype) self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] - self.ln_post = nn.LayerNorm(n_state) + self.ln_post = LayerNorm(n_state) def __call__(self, x): x = nn.gelu(self.conv1(x)) @@ -144,7 +148,7 @@ class AudioEncoder(nn.Module): class TextDecoder(nn.Module): def __init__( - self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, ): super().__init__() @@ -155,8 +159,8 @@ class TextDecoder(nn.Module): ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer) ] - self.ln = nn.LayerNorm(n_state) - self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx) + self.ln = LayerNorm(n_state) + self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype) def __call__(self, x, xa, kv_cache=None): """ @@ -181,7 +185,7 @@ class TextDecoder(nn.Module): class Whisper(nn.Module): - def __init__(self, dims: ModelDimensions): + def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16): super().__init__() self.dims = dims self.encoder = AudioEncoder( @@ -190,6 +194,7 @@ class Whisper(nn.Module): self.dims.n_audio_state, self.dims.n_audio_head, self.dims.n_audio_layer, + dtype, ) 
self.decoder = TextDecoder(
            self.dims.n_vocab,
@@ -197,6 +202,7 @@ class Whisper(nn.Module):
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
+           dtype,
        )

    def embed_audio(self, mel):

From 94705ed38bb5b07d7440f50e421d8e01404d15f1 Mon Sep 17 00:00:00 2001
From: bofenghuang
Date: Tue, 12 Dec 2023 17:26:52 +0100
Subject: [PATCH 2/7] Add large v3

---
 whisper/whisper/assets/mel_filters.npz | Bin 2048 -> 4271 bytes
 whisper/whisper/audio.py               |   8 ++++----
 whisper/whisper/decoding.py            |   9 +++++++--
 whisper/whisper/load_models.py         |   6 ++++--
 whisper/whisper/tokenizer.py           |  27 ++++++++++++++++---------
 whisper/whisper/torch_whisper.py       |   9 +++++++--
 whisper/whisper/transcribe.py          |   9 +++++++--
 whisper/whisper/whisper.py             |   6 +++++-
 8 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/whisper/whisper/assets/mel_filters.npz b/whisper/whisper/assets/mel_filters.npz
index 1a7839244dfb6b1cc02e4f3cfe12e4817a073bc7..28ea26909dbdfd608aef67afc4d74d7961ae4bb6 100644
GIT binary patch
[base85-encoded binary payload omitted: "literal 4271" and "delta 2007" blocks replacing the bundled mel filterbank data (updated to also carry a 128-bin filterbank for large-v3)]

diff --git a/whisper/whisper/audio.py b/whisper/whisper/audio.py
--- a/whisper/whisper/audio.py
+++ b/whisper/whisper/audio.py
-def mel_filters(n_mels: int = N_MELS) -> mx.array:
+def mel_filters(n_mels: int) -> mx.array:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using: @@ -89,9 +88,10 @@ def mel_filters(n_mels: int = N_MELS) -> mx.array: np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) """ - assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") return mx.load(filename)[f"mel_{n_mels}"] @@ -130,7 +130,7 @@ def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="re def log_mel_spectrogram( audio: Union[str, np.ndarray], - n_mels: int = N_MELS, + n_mels: int = 80, padding: int = 0, ): """ diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index c4b6326d..dd0fa5e5 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -33,7 +33,9 @@ def detect_language( list of dictionaries containing the probability distribution over all languages. """ if tokenizer is None: - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer( + model.is_multilingual, num_languages=model.num_languages + ) if ( tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence @@ -401,7 +403,10 @@ class DecodingTask: language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=options.task + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, ) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 2d0ae578..2d9b1fea 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -25,7 +25,8 @@ _MODELS = { "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", - "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", + "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", } # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are @@ -41,7 +42,8 @@ _ALIGNMENT_HEADS = { "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", - "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", + "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", + "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00" } diff --git a/whisper/whisper/tokenizer.py b/whisper/whisper/tokenizer.py index 5e345508..b589f764 100644 --- a/whisper/whisper/tokenizer.py +++ b/whisper/whisper/tokenizer.py @@ -109,6 +109,7 @@ LANGUAGES = { "ba": "bashkir", 
"jw": "javanese", "su": "sundanese", + "yue": "cantonese", } # language code lookup by name, with a few language aliases @@ -125,6 +126,7 @@ TO_LANGUAGE_CODE = { "moldovan": "ro", "sinhalese": "si", "castilian": "es", + "mandarin": "zh", } @@ -133,6 +135,7 @@ class Tokenizer: """A thin wrapper around `tiktoken` providing quick access to special tokens""" encoding: tiktoken.Encoding + num_languages: int language: Optional[str] = None task: Optional[str] = None sot_sequence: Tuple[int] = () @@ -147,7 +150,7 @@ class Tokenizer: translate: int = self.special_tokens["<|translate|>"] transcribe: int = self.special_tokens["<|transcribe|>"] - langs = tuple(LANGUAGES.keys()) + langs = tuple(LANGUAGES.keys())[: self.num_languages] sot_sequence = [sot] if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language)) @@ -213,10 +216,13 @@ class Tokenizer: if self.language is None: raise ValueError("This tokenizer does not have language token configured") - if token := self.special_tokens.get(f"<|{self.language}|>", None): + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): return token - raise KeyError(f"Language {self.language} not found in tokenizer.") + raise KeyError(f"Language {language} not found in tokenizer.") @cached_property def all_language_tokens(self) -> Tuple[int]: @@ -224,7 +230,7 @@ class Tokenizer: for token, token_id in self.special_tokens.items(): if token.strip("<|>") in LANGUAGES: result.append(token_id) - return tuple(result) + return tuple(result)[: self.num_languages] @cached_property def all_language_codes(self) -> Tuple[str]: @@ -271,7 +277,7 @@ class Tokenizer: return tuple(sorted(result)) def split_to_word_tokens(self, tokens: List[int]): - if self.language in {"zh", "ja", "th", "lo", "my"}: + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: # These languages don't typically use spaces, so it is difficult to split words # without morpheme analysis. 
Here, we instead split words at any # position where the tokens are decoded as valid unicode points @@ -324,7 +330,7 @@ class Tokenizer: @lru_cache(maxsize=None) -def get_encoding(name: str = "gpt2"): +def get_encoding(name: str = "gpt2", num_languages: int = 99): vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") with open(vocab_path) as fid: ranks = { @@ -337,7 +343,7 @@ def get_encoding(name: str = "gpt2"): specials = [ "<|endoftext|>", "<|startoftranscript|>", - *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", @@ -364,6 +370,7 @@ def get_encoding(name: str = "gpt2"): def get_tokenizer( multilingual: bool, *, + num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None, # Literal["transcribe", "translate", None] ) -> Tokenizer: @@ -384,6 +391,8 @@ def get_tokenizer( language = None task = None - encoding = get_encoding(name=encoding_name) + encoding = get_encoding(name=encoding_name, num_languages=num_languages) - return Tokenizer(encoding=encoding, language=language, task=task) + return Tokenizer( + encoding=encoding, num_languages=num_languages, language=language, task=task + ) diff --git a/whisper/whisper/torch_whisper.py b/whisper/whisper/torch_whisper.py index 0ffcf302..3b5491e4 100644 --- a/whisper/whisper/torch_whisper.py +++ b/whisper/whisper/torch_whisper.py @@ -234,7 +234,8 @@ class Whisper(nn.Module): self.dims.n_text_head, self.dims.n_text_layer, ) - # use the last half layers for alignment by default; see `set_alignment_heads()` below + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. 
all_heads = torch.zeros( self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool ) @@ -267,7 +268,11 @@ class Whisper(nn.Module): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) def install_kv_cache_hooks(self, cache: Optional[dict] = None): """ diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index bfdc32b5..40b9ec40 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -119,7 +119,7 @@ def transcribe( dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32 # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, padding=N_SAMPLES) + mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) content_frames = mel.shape[-2] - N_FRAMES if verbose: @@ -150,7 +150,12 @@ def transcribe( language: str = decode_options["language"] task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=task, + ) def decode_with_fallback(segment: mx.array) -> DecodingResult: temperatures = ( diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index ec60c6ec..983819ef 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -210,7 +210,11 @@ class Whisper(nn.Module): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) detect_language = detect_language_function decode = decode_function From f0c57c1361b4084976f4e39df3a6bc82d44ef40d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 12 Dec 2023 12:48:15 -0800 Subject: [PATCH 3/7] llama v2 with sharded weights --- llama/README.md | 25 +++-- llama/convert.py | 78 +++++++++------- llama/llama.py | 206 ++++++++++++++++++++++++++--------------- llama/requirements.txt | 1 + mixtral/README.md | 2 +- 5 files changed, 189 insertions(+), 123 deletions(-) diff --git a/llama/README.md b/llama/README.md index b9f487dd..da4e85f3 100644 --- a/llama/README.md +++ b/llama/README.md @@ -1,8 +1,9 @@ -# LLaMA +# Llama -An example of generating text with LLaMA using MLX. +An example of generating text with Llama (1 or 2) using MLX. -LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters. +Llama is a set of open source language models from Meta AI Research[^1][^2] +ranging from 7B to 70B parameters. ### Setup @@ -14,27 +15,31 @@ pip install -r requirements.txt Next, download and convert the model. If you do not have access to the model weights you will need to [request -access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) +access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) from Meta. - -Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging Face and skip the conversion step. +Alternatively, you can also download a select converted checkpoints from the +[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging +Face and skip the conversion step. 
Convert the weights with: ``` -python convert.py +python convert.py --model_path ``` +The conversion script will save the converted weights in the same location. + ### Run Once you've converted the weights to MLX format, you can interact with the -LLaMA model: +LlaMA model: ``` -python llama.py "hello" +python llama.py "hello" ``` Run `python llama.py --help` for more details. -[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details. +[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details. +[^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/) diff --git a/llama/convert.py b/llama/convert.py index 69168493..89ce8a36 100644 --- a/llama/convert.py +++ b/llama/convert.py @@ -1,53 +1,59 @@ # Copyright © 2023 Apple Inc. import argparse -from itertools import starmap +import collections +import glob +from pathlib import Path import numpy as np import torch +SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"] +SHARD_SECOND = ["tok_embeddings", "wo", "w2"] +SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND) -def map_torch_to_mlx(key, value): - if "tok_embedding" in key: - key = "embedding.weight" - elif "norm" in key: - key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2") +def shard_key(k): + keys = k.split(".") + if len(keys) < 2: + return None + return keys[-2] - elif "wq" in key or "wk" in key or "wv" in key or "wo" in key: - key = key.replace("wq", "query_proj") - key = key.replace("wk", "key_proj") - key = key.replace("wv", "value_proj") - key = key.replace("wo", "out_proj") - elif "w1" in key or "w2" in key or "w3" in key: - # The FFN is a separate submodule in PyTorch - key = key.replace("feed_forward.w1", "linear1") - key = key.replace("feed_forward.w3", "linear2") - key = key.replace("feed_forward.w2", "linear3") - - elif "output" in key: - key = key.replace("output", "out_proj") - - elif "rope" in key: - return None, None - - return ( - key, - value.numpy() - if value.dtype != torch.bfloat16 - else value.to(torch.float32).numpy(), - ) +def unshard(k, v): + wn = shard_key(k) + if wn not in SHARD_WEIGHTS: + return v + elif wn in SHARD_FIRST: + axis = 0 + elif wn in SHARD_SECOND: + axis = 1 + else: + raise ValueError("Invalid weight name") + return np.concatenate(v, axis=axis) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") - parser.add_argument("torch_weights") - parser.add_argument("output_file") + parser.add_argument( + "--model_path", + help="Path to the Torch model. 
The MLX weights will also be saved there.", + ) args = parser.parse_args() - state = torch.load(args.torch_weights, map_location=torch.device('cpu')) - np.savez( - args.output_file, - **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None} - ) + model_path = Path(args.model_path) + torch_files = glob.glob(str(model_path / "consolidated.*.pth")) + weights = collections.defaultdict(list) + for wf in torch_files: + state = torch.load(wf, map_location=torch.device("cpu")) + for k, v in state.items(): + v = v.to(torch.float16).numpy() + if shard_key(k) in SHARD_WEIGHTS: + weights[k].append(v) + else: + weights[k] = v + + out_file = str(model_path / "weights.npz") + for k, v in weights.items(): + weights[k] = unshard(k, v) + np.savez(out_file, **weights) diff --git a/llama/llama.py b/llama/llama.py index c18728ff..db9c8db3 100644 --- a/llama/llama.py +++ b/llama/llama.py @@ -1,8 +1,10 @@ # Copyright © 2023 Apple Inc. import argparse -import math -import numpy as np +from dataclasses import dataclass +import json +from pathlib import Path +from typing import Optional, Tuple, List from sentencepiece import SentencePieceProcessor import time @@ -11,33 +13,71 @@ import mlx.nn as nn from mlx.utils import tree_unflatten -class LlamaAttention(nn.Module): - def __init__(self, dims: int, num_heads: int): +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps - self.num_heads = num_heads + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - self.rope = nn.RoPE(dims // num_heads, traditional=True) - self.query_proj = nn.Linear(dims, dims, bias=False) - self.key_proj = nn.Linear(dims, dims, bias=False) - self.value_proj = nn.Linear(dims, dims, bias=False) - self.out_proj = nn.Linear(dims, dims, bias=False) + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output - def __call__(self, queries, keys, values, mask=None, cache=None): - queries = self.query_proj(queries) - keys = self.key_proj(keys) - values = self.value_proj(values) - # Extract some shapes - num_heads = self.num_heads - B, L, D = queries.shape +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE(args.head_dim, traditional=True) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + queries = 
queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + keys, values = map(repeat, (keys, values)) - # Add RoPE to the queries and keys and combine them with the cache if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) @@ -48,86 +88,87 @@ class LlamaAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: - scores = scores + mask - scores = mx.softmax(scores, axis=-1) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - # Note that we return the keys and values to possibly be used as a cache - return self.out_proj(values_hat), (keys, values) + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) -class LlamaEncoderLayer(nn.Module): - def __init__(self, dims: int, mlp_dims: int, num_heads: int): +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): super().__init__() - self.attention = LlamaAttention(dims, num_heads) + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) - self.norm1 = nn.RMSNorm(dims) - self.norm2 = nn.RMSNorm(dims) + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) - self.linear1 = nn.Linear(dims, mlp_dims, bias=False) - self.linear2 = nn.Linear(dims, mlp_dims, bias=False) - self.linear3 = nn.Linear(mlp_dims, dims, bias=False) - def __call__(self, x, mask=None, cache=None): - y = self.norm1(x) - y, cache = self.attention(y, y, y, mask, cache) - x = x + y +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args - y = self.norm2(x) - a = self.linear1(y) - b = self.linear2(y) - y = a * mx.sigmoid(a) * b - y = self.linear3(y) - x = x + y - - return x, cache + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache class Llama(nn.Module): - def __init__( - self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int - ): + def __init__(self, args: ModelArgs): super().__init__() - - self.embedding = nn.Embedding(vocab_size, dims) - self.layers = [ - LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers) - ] - self.norm = nn.RMSNorm(dims) - self.out_proj = nn.Linear(dims, vocab_size, bias=False) + self.args = args + self.vocab_size = args.vocab_size + self.tok_embeddings = 
nn.Embedding(args.vocab_size, args.dim) + self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) def __call__(self, x): mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.embedding.weight.dtype) + mask = mask.astype(self.tok_embeddings.weight.dtype) - x = self.embedding(x) + x = self.tok_embeddings(x) for l in self.layers: x, _ = l(x, mask) x = self.norm(x) - return self.out_proj(x) + return self.output(x) def generate(self, x, temp=1.0): cache = [] # Make an additive causal mask. We will need that to process the prompt. mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.embedding.weight.dtype) + mask = mask.astype(self.tok_embeddings.weight.dtype) # First we process the prompt x the same was as in __call__ but # save the caches in cache - x = self.embedding(x) + x = self.tok_embeddings(x) for l in self.layers: x, c = l(x, mask=mask) # We store the per layer cache in a simple python list cache.append(c) x = self.norm(x) # We only care about the last logits that generate the next token - y = self.out_proj(x[:, -1]) + y = self.output(x[:, -1]) y = mx.random.categorical(y * (1 / temp)) # y now has size [1] @@ -145,14 +186,14 @@ class Llama(nn.Module): # dimension of 1 x = y[:, None] - x = self.embedding(x) + x = self.tok_embeddings(x) for i in range(len(cache)): # We are overwriting the arrays in the cache list. When # the computation will happen, MLX will be discarding the # old cache the moment it is not needed anymore. x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) x = self.norm(x) - y = self.out_proj(x[:, -1]) + y = self.output(x[:, -1]) y = mx.random.categorical(y * (1 / temp)) yield y @@ -261,20 +302,33 @@ def few_shot_generate(args): def load_model(model_path): - weights = mx.load(model_path) - mlp_dims, dims = weights["layers.0.linear1.weight"].shape - num_heads = dims // 128 - num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1 - vocab_size = weights["out_proj.weight"].shape[-1] - model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads) + model_path = Path(model_path) + weights = mx.load(str(model_path / "weights.npz")) + with open(model_path / "params.json", "r") as f: + config = json.loads(f.read()) + n_heads = config["n_heads"] + if "n_kv_heads" not in config: + config["n_kv_heads"] = n_heads + if "head_dim" not in config: + config["head_dim"] = config["dim"] // n_heads + if "hidden_dim" not in config: + config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] + if config.get("vocab_size", -1) < 0: + config["vocab_size"] = weights["output.weight"].shape[-1] + unused = ["multiple_of", "ffn_dim_multiplie"] + for k in unused: + if k in config: + config.pop(k) + model = Llama(ModelArgs(**config)) model.update(tree_unflatten(list(weights.items()))) - mx.eval(model.parameters()) return model if __name__ == "__main__": parser = argparse.ArgumentParser(description="Llama inference script") - parser.add_argument("model", help="The model file containing MLX weights") + parser.add_argument( + "model", help="Path to the model directory containing the MLX weights" + ) parser.add_argument("tokenizer", help="The sentencepiece tokenizer") parser.add_argument("prompt", help="The message to be processed by the model") parser.add_argument( diff --git a/llama/requirements.txt b/llama/requirements.txt index 
c036fa59..7111f1d4 100644 --- a/llama/requirements.txt +++ b/llama/requirements.txt @@ -1,2 +1,3 @@ +mlx sentencepiece torch diff --git a/mixtral/README.md b/mixtral/README.md index 23de1430..494e8107 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -46,7 +46,7 @@ rm mixtral-8x7b-32kseqlen/*.pth* As easy as: ``` -python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/ +python mixtral.py --model_path mixtral-8x7b-32kseqlen/ ``` [^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. From e0a53edb465353468fcbcf4f221f471ef6758d82 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 12 Dec 2023 13:32:05 -0800 Subject: [PATCH 4/7] llama v1 request --- llama/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llama/README.md b/llama/README.md index da4e85f3..3ad882de 100644 --- a/llama/README.md +++ b/llama/README.md @@ -14,9 +14,11 @@ pip install -r requirements.txt ``` Next, download and convert the model. If you do not have access to the model -weights you will need to [request -access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) -from Meta. +weights you will need to request access from Meta: + +- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) +- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) + Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging From 2206e8f7d9485d9cf4176e560b20bd4408ac97f6 Mon Sep 17 00:00:00 2001 From: Merrick Christensen Date: Tue, 12 Dec 2023 14:33:33 -0700 Subject: [PATCH 5/7] Update convert.py Docs are right, however, the code has a typo. --- mixtral/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mixtral/convert.py b/mixtral/convert.py index a1a423d0..e67f4453 100644 --- a/mixtral/convert.py +++ b/mixtral/convert.py @@ -16,7 +16,7 @@ if __name__ == "__main__": ) args = parser.parse_args() model_path = Path(args.model_path) - state = torch.load(str(model_path / "consolidated.00.pt")) + state = torch.load(str(model_path / "consolidated.00.pth")) np.savez( str(model_path / "weights.npz"), **{k: v.to(torch.float16).numpy() for k, v in state.items()}, From 2e6a6c32aec7d2bf0e6f1808ac85c38fdc0bef32 Mon Sep 17 00:00:00 2001 From: Ashraful Islam Date: Tue, 12 Dec 2023 18:26:13 -0600 Subject: [PATCH 6/7] Update README.md updates readme with recently added examples --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index f60c140a..37c977ed 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,12 @@ Some more useful examples include: - [Transformer language model](transformer_lm) training. - Large scale text generation with [LLaMA](llama) or [Mistral](mistral). +- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) - Parameter efficient fine-tuning with [LoRA](lora). - Generating images with [Stable Diffusion](stable_diffusion). - Speech recognition with [OpenAI's Whisper](whisper). +- Bidirectional language understanding with [BERT](bert) +- Semi-supervised learning on graph-structured data with [GCN](gcn). 
## Contributing From a99e9d551e0ba7b994e97fb3efc01f0f40d621d8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 12 Dec 2023 17:08:04 -0800 Subject: [PATCH 7/7] hf correction --- bert/hf_model.py | 2 +- mixtral/README.md | 2 +- stable_diffusion/README.md | 4 ++-- stable_diffusion/stable_diffusion/model_io.py | 8 ++++---- stable_diffusion/stable_diffusion/tokenizer.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/bert/hf_model.py b/bert/hf_model.py index 4f07df13..e63b904b 100644 --- a/bert/hf_model.py +++ b/bert/hf_model.py @@ -25,7 +25,7 @@ def run(bert_model: str): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Run the BERT model using HuggingFace Transformers." + description="Run the BERT model using Hugging Face Transformers." ) parser.add_argument( "--bert-model", diff --git a/mixtral/README.md b/mixtral/README.md index 494e8107..417759e1 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -14,7 +14,7 @@ For example with Homebrew: brew install git-lfs ``` -Download the models from HuggingFace: +Download the models from Hugging Face: ``` git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md index e18a4d49..400a50f7 100644 --- a/stable_diffusion/README.md +++ b/stable_diffusion/README.md @@ -1,9 +1,9 @@ Stable Diffusion ================ -Stable Diffusion in MLX. The implementation was ported from Huggingface's +Stable Diffusion in MLX. The implementation was ported from Hugging Face's [diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching -and using the weights available on the Huggingface Hub by Stability AI at +and using the weights available on the Hugging Face Hub by Stability AI at [stabilitiai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1). 
![out](generated-mlx.png) diff --git a/stable_diffusion/stable_diffusion/model_io.py b/stable_diffusion/stable_diffusion/model_io.py index 7eef4e28..c2669de4 100644 --- a/stable_diffusion/stable_diffusion/model_io.py +++ b/stable_diffusion/stable_diffusion/model_io.py @@ -169,7 +169,7 @@ def _check_key(key: str, part: str): def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False): - """Load the stable diffusion UNet from Huggingface Hub.""" + """Load the stable diffusion UNet from Hugging Face Hub.""" _check_key(key, "load_unet") # Download the config and create the model @@ -199,7 +199,7 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False): def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False): - """Load the stable diffusion text encoder from Huggingface Hub.""" + """Load the stable diffusion text encoder from Hugging Face Hub.""" _check_key(key, "load_text_encoder") # Download the config and create the model @@ -226,7 +226,7 @@ def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False): def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False): - """Load the stable diffusion autoencoder from Huggingface Hub.""" + """Load the stable diffusion autoencoder from Hugging Face Hub.""" _check_key(key, "load_autoencoder") # Download the config and create the model @@ -255,7 +255,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False): def load_diffusion_config(key: str = _DEFAULT_MODEL): - """Load the stable diffusion config from Huggingface Hub.""" + """Load the stable diffusion config from Hugging Face Hub.""" _check_key(key, "load_diffusion_config") diffusion_config = _MODELS[key]["diffusion_config"] diff --git a/stable_diffusion/stable_diffusion/tokenizer.py b/stable_diffusion/stable_diffusion/tokenizer.py index 07375fc7..ae9b967a 100644 --- a/stable_diffusion/stable_diffusion/tokenizer.py +++ b/stable_diffusion/stable_diffusion/tokenizer.py @@ -81,7 +81,7 @@ class Tokenizer: if isinstance(text, list): return [self.tokenize(t, prepend_bos, append_eos) for t in text] - # Lower case cleanup and split according to self.pat. Huggingface does + # Lower case cleanup and split according to self.pat. Hugging Face does # a much more thorough job here but this should suffice for 95% of # cases. clean_text = regex.sub(r"\s+", " ", text.lower())
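For reference, here is a minimal usage sketch of the loaders touched in `model_io.py` above. The function names, the default model `key`, and the `float16` flag come from the diff itself; the import path and the way the pieces are combined are assumptions about the example's package layout rather than part of the patch.

```python
# Hedged sketch (not part of the patch): load the Stable Diffusion components
# whose docstrings were updated above. Each loader fetches its weights from the
# Hugging Face Hub on first use.
import mlx.core as mx

from stable_diffusion import model_io  # assumed import path for this example

unet = model_io.load_unet(float16=True)                  # denoising UNet
text_encoder = model_io.load_text_encoder(float16=True)  # prompt text encoder
autoencoder = model_io.load_autoencoder(float16=True)    # image autoencoder
diffusion_config = model_io.load_diffusion_config()      # diffusion/scheduler settings

# Force the lazily loaded weights to be materialized.
mx.eval(unet.parameters(), text_encoder.parameters(), autoencoder.parameters())
```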