+ glorot_uniform() (in module mlx.nn.init)
+
grad() (in module mlx.core)
greater() (in module mlx.core)
@@ -1244,8 +1289,6 @@ document.write(`
SiLU (class in mlx.nn)
silu (class in mlx.nn)
- simplify() (in module mlx.core)
sin() (in module mlx.core)
@@ -1255,15 +1298,19 @@ document.write(`
sinh() (in module mlx.core)
FFT
Linear Algebra
Neural Networks
@@ -432,6 +435,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -440,6 +444,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -451,14 +456,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
diff --git a/docs/build/html/objects.inv b/docs/build/html/objects.inv
index 965ee3814bb7567692a48763f88ff8a1c56ecb04..b6a3290925121465eb5acd9e9c73bbfb9a372aa0 100644
Binary files a/docs/build/html/objects.inv and b/docs/build/html/objects.inv differ
diff --git a/docs/build/html/python/_autosummary/mlx.core.Device.html b/docs/build/html/python/_autosummary/mlx.core.Device.html
index 8ee160470..d73f5e65a 100644
--- a/docs/build/html/python/_autosummary/mlx.core.Device.html
+++ b/docs/build/html/python/_autosummary/mlx.core.Device.html
@@ -9,7 +9,7 @@
- mlx.core.Device — MLX 0.0.9 documentation
+ mlx.core.Device — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -650,8 +667,9 @@ document.write(`
mlx.core.allclose
-mlx.core. allclose ( a : array , b : array , / , rtol : float = 1e-05 , atol : float = 1e-08 , * , stream : Union [ None , Stream , Device ] = None ) → array
+mlx.core. allclose ( a : array , b : array , / , rtol : float = 1e-05 , atol : float = 1e-08 , * , equal_nan : bool = False , stream : Union [ None , Stream , Device ] = None ) → array
Approximate comparison of two arrays.
+Infinite values are considered equal if they have the same sign; NaN values are not equal unless equal_nan is True.
The arrays are considered equal if:
all(abs(a - b) <= (atol + rtol * abs(b)))
@@ -665,6 +683,8 @@ broadcasting.
b (array ) – Input array.
rtol (float ) – Relative tolerance.
atol (float ) – Absolute tolerance.
+equal_nan (bool) – If True, NaNs are considered equal. Defaults to False.
Returns:
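A minimal usage sketch of the new equal_nan flag (illustrative, not from the rendered page; assumes import mlx.core as mx, and the output reprs are representative):
>>> a = mx.array([1.0, float("nan")])
>>> mx.allclose(a, a)
array(False, dtype=bool)
>>> mx.allclose(a, a, equal_nan=True)
array(True, dtype=bool)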
diff --git a/docs/build/html/python/_autosummary/mlx.core.any.html b/docs/build/html/python/_autosummary/mlx.core.any.html
index a57b49010..8047fabc8 100644
--- a/docs/build/html/python/_autosummary/mlx.core.any.html
+++ b/docs/build/html/python/_autosummary/mlx.core.any.html
@@ -9,7 +9,7 @@
- mlx.core.any — MLX 0.0.9 documentation
+ mlx.core.any — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
Returns:
-The output array with the indices of the maximum values.
+The uint32 array with the indices of the maximum values.
Return type:
array
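A quick sketch of the documented uint32 result (assumes import mlx.core as mx; output repr is illustrative):
>>> mx.argmax(mx.array([1, 5, 3]))
array(1, dtype=uint32)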
diff --git a/docs/build/html/python/_autosummary/mlx.core.argmin.html b/docs/build/html/python/_autosummary/mlx.core.argmin.html
index c7fc4c489..44b21c3d0 100644
--- a/docs/build/html/python/_autosummary/mlx.core.argmin.html
+++ b/docs/build/html/python/_autosummary/mlx.core.argmin.html
@@ -9,7 +9,7 @@
- mlx.core.argmin — MLX 0.0.9 documentation
+ mlx.core.argmin — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
Returns:
-The output array with the indices of the minimum values.
+The uint32 array with the indices of the minimum values.
Return type:
array
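As with argmax, a short sketch of the uint32 result (assumes import mlx.core as mx; output repr is illustrative):
>>> mx.argmin(mx.array([1, 5, 3]))
array(0, dtype=uint32)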
diff --git a/docs/build/html/python/_autosummary/mlx.core.argpartition.html b/docs/build/html/python/_autosummary/mlx.core.argpartition.html
index 6473ab649..15ae2b135 100644
--- a/docs/build/html/python/_autosummary/mlx.core.argpartition.html
+++ b/docs/build/html/python/_autosummary/mlx.core.argpartition.html
@@ -9,7 +9,7 @@
- mlx.core.argpartition — MLX 0.0.9 documentation
+ mlx.core.argpartition — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -669,7 +686,7 @@ If unspecified, it defaults to
Returns:
-The indices that partition the input array.
+The uint32 array containing indices that partition the input.
Return type:
array
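A brief sketch of using the returned indices (assumes import mlx.core as mx; with kth=2, the element at position 2 of a[idx] lands in its sorted place):
>>> a = mx.array([4, 1, 3, 2])
>>> idx = mx.argpartition(a, kth=2)
>>> a[idx]  # partitioned view of a; idx has dtype uint32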
diff --git a/docs/build/html/python/_autosummary/mlx.core.argsort.html b/docs/build/html/python/_autosummary/mlx.core.argsort.html
index 217a89f6b..ba9bcf00f 100644
--- a/docs/build/html/python/_autosummary/mlx.core.argsort.html
+++ b/docs/build/html/python/_autosummary/mlx.core.argsort.html
@@ -9,7 +9,7 @@
- mlx.core.argsort — MLX 0.0.9 documentation
+ mlx.core.argsort — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
Returns:
-The indices that sort the input array.
+The uint32 array containing indices that sort the input.
Return type:
array
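A short sketch of sorting via the returned uint32 indices (assumes import mlx.core as mx; outputs illustrative):
>>> a = mx.array([3, 1, 2])
>>> idx = mx.argsort(a)
>>> a[idx]
array([1, 2, 3], dtype=int32)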
diff --git a/docs/build/html/python/_autosummary/mlx.core.array.T.html b/docs/build/html/python/_autosummary/mlx.core.array.T.html
index 34400a979..31fc555b2 100644
--- a/docs/build/html/python/_autosummary/mlx.core.array.T.html
+++ b/docs/build/html/python/_autosummary/mlx.core.array.T.html
@@ -9,7 +9,7 @@
- mlx.core.array.T — MLX 0.0.9 documentation
+ mlx.core.array.T — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -696,6 +713,12 @@ document.write(`
cumsum(self[, axis, reverse, inclusive, stream])
See cumsum().
+diag(self[, k, stream])
+Extract a diagonal or construct a diagonal matrix.
+
+diagonal(self[, offset, axis1, axis2, stream])
+See diagonal().
+
exp(self, *[, stream])
See exp().
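A sketch of the two new array methods, which mirror mlx.core.diag and mlx.core.diagonal (assumes import mlx.core as mx; output reprs are illustrative):
>>> m = mx.array([[1, 2], [3, 4]])
>>> m.diagonal()
array([1, 4], dtype=int32)
>>> mx.array([1, 2]).diag()
array([[1, 0],
       [0, 2]], dtype=int32)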
diff --git a/docs/build/html/python/_autosummary/mlx.core.array.item.html b/docs/build/html/python/_autosummary/mlx.core.array.item.html
index 1c4b8f266..8b8b258e8 100644
--- a/docs/build/html/python/_autosummary/mlx.core.array.item.html
+++ b/docs/build/html/python/_autosummary/mlx.core.array.item.html
@@ -9,7 +9,7 @@
- mlx.core.array.item — MLX 0.0.9 documentation
+ mlx.core.array.item — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -660,7 +677,7 @@ the same type to be considered equal.
a (array ) – Input array or scalar.
b (array ) – Input array or scalar.
-equal_nan (bool) – If True, NaNs are treated as equal.
+equal_nan (bool) – If True, NaNs are considered equal. Defaults to False.
diff --git a/docs/build/html/python/_autosummary/mlx.core.broadcast_to.html b/docs/build/html/python/_autosummary/mlx.core.broadcast_to.html
index e945cd843..af9088e5a 100644
--- a/docs/build/html/python/_autosummary/mlx.core.broadcast_to.html
+++ b/docs/build/html/python/_autosummary/mlx.core.broadcast_to.html
@@ -9,7 +9,7 @@
- mlx.core.broadcast_to — MLX 0.0.9 documentation
+ mlx.core.broadcast_to — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
diff --git a/docs/build/html/python/_autosummary/mlx.core.dequantize.html b/docs/build/html/python/_autosummary/mlx.core.dequantize.html
index fff025d69..4c014b968 100644
--- a/docs/build/html/python/_autosummary/mlx.core.dequantize.html
+++ b/docs/build/html/python/_autosummary/mlx.core.dequantize.html
@@ -9,7 +9,7 @@
- mlx.core.dequantize — MLX 0.0.9 documentation
+ mlx.core.dequantize — MLX 0.1.0 documentation
@@ -242,6 +242,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -352,7 +354,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -434,6 +437,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -442,6 +446,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -453,14 +458,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -703,11 +720,11 @@ scale and bias. (default:
next
-mlx.core.divide
+mlx.core.diag
diff --git a/docs/build/html/python/_autosummary/mlx.core.diag.html b/docs/build/html/python/_autosummary/mlx.core.diag.html
new file mode 100644
index 000000000..13ddf18c6
--- /dev/null
+++ b/docs/build/html/python/_autosummary/mlx.core.diag.html
@@ -0,0 +1,780 @@
+ mlx.core.diag — MLX 0.1.0 documentation
+mlx.core.diag
+
+
+mlx.core.diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) → array
+Extract a diagonal or construct a diagonal matrix.
+If a is 1-D then a diagonal matrix is constructed with a on the \(k\)-th diagonal. If a is 2-D then the \(k\)-th diagonal is returned.
+
+Parameters:
+
+a (array) – 1-D or 2-D input array.
+k (int, optional) – The diagonal to extract or construct. Default: 0.
+
+
+Returns:
+The extracted diagonal or the constructed diagonal matrix.
+
+Return type:
+array
+
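+Example (a sketch; assumes import mlx.core as mx, output reprs illustrative):
+>>> mx.diag(mx.array([[1, 2], [3, 4]]))
+array([1, 4], dtype=int32)
+>>> mx.diag(mx.array([[1, 2], [3, 4]]), k=1)
+array([2], dtype=int32)
+>>> mx.diag(mx.array([1, 2]))
+array([[1, 0],
+       [0, 2]], dtype=int32)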
\ No newline at end of file
diff --git a/docs/build/html/python/_autosummary/mlx.core.diagonal.html b/docs/build/html/python/_autosummary/mlx.core.diagonal.html
new file mode 100644
index 000000000..6c6b9c88b
--- /dev/null
+++ b/docs/build/html/python/_autosummary/mlx.core.diagonal.html
@@ -0,0 +1,786 @@
+ mlx.core.diagonal — MLX 0.1.0 documentation
+mlx.core.diagonal
+
+
+mlx.core.diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) → array
+Return specified diagonals.
+If a is 2-D, then a 1-D array containing the diagonal at the given offset is returned.
+If a has more than two dimensions, then axis1 and axis2 determine the 2-D subarrays from which diagonals are extracted. The new shape is the original shape with axis1 and axis2 removed and a new dimension inserted at the end corresponding to the diagonal.
+
+Parameters:
+
+a (array) – Input array.
+offset (int, optional) – Offset of the diagonal from the main diagonal. Can be positive or negative. Default: 0.
+axis1 (int, optional) – The first axis of the 2-D sub-arrays from which the diagonals should be taken. Default: 0.
+axis2 (int, optional) – The second axis of the 2-D sub-arrays from which the diagonals should be taken. Default: 1.
+
+
+Returns:
+The diagonals of the array.
+
+Return type:
+array
+
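+Example (a sketch; assumes import mlx.core as mx, output reprs illustrative):
+>>> a = mx.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+>>> mx.diagonal(a)
+array([0, 4, 8], dtype=int32)
+>>> mx.diagonal(a, offset=1)
+array([1, 5], dtype=int32)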
\ No newline at end of file
diff --git a/docs/build/html/python/_autosummary/mlx.core.divide.html b/docs/build/html/python/_autosummary/mlx.core.divide.html
index 249373294..b9d02c2dc 100644
--- a/docs/build/html/python/_autosummary/mlx.core.divide.html
+++ b/docs/build/html/python/_autosummary/mlx.core.divide.html
@@ -9,7 +9,7 @@
- mlx.core.divide — MLX 0.0.9 documentation
+ mlx.core.divide — MLX 0.1.0 documentation
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -652,6 +669,10 @@ document.write(`
mlx.core.flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: Union[None, Stream, Device] = None) → array
Flatten an array.
+The axes flattened will be between start_axis and end_axis, inclusive. Negative axes are supported. After converting negative axis to positive, axes outside the valid range will be clamped to a valid value: start_axis to 0 and end_axis to ndim - 1.
Parameters:
@@ -669,6 +690,15 @@ in which case the default stream of the default device is used.
array
+Example
+>>> a = mx.array([[1, 2], [3, 4]])
+>>> mx.flatten(a)
+array([1, 2, 3, 4], dtype=int32)
+>>>
+>>> mx.flatten(a, start_axis=0, end_axis=-1)
+array([1, 2, 3, 4], dtype=int32)
+
+
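+A sketch of the start_axis/end_axis behavior on a 3-D input (assumes import mlx.core as mx; resulting shapes noted in comments):
+>>> b = mx.zeros((2, 3, 4))
+>>> mx.flatten(b, start_axis=1)   # flattens the last two axes -> shape (2, 12)
+>>> mx.flatten(b, end_axis=-2)    # flattens the first two axes -> shape (6, 4)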
diff --git a/docs/build/html/python/_autosummary/mlx.core.floor.html b/docs/build/html/python/_autosummary/mlx.core.floor.html
index 64967825a..857e80b7d 100644
--- a/docs/build/html/python/_autosummary/mlx.core.floor.html
+++ b/docs/build/html/python/_autosummary/mlx.core.floor.html
@@ -9,7 +9,7 @@
- mlx.core.floor — MLX 0.0.9 documentation
+ mlx.core.floor — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
diff --git a/docs/build/html/python/_autosummary/mlx.core.linalg.norm.html b/docs/build/html/python/_autosummary/mlx.core.linalg.norm.html
index 044209bf6..60bbf5b27 100644
--- a/docs/build/html/python/_autosummary/mlx.core.linalg.norm.html
+++ b/docs/build/html/python/_autosummary/mlx.core.linalg.norm.html
@@ -9,7 +9,7 @@
- mlx.core.linalg.norm — MLX 0.0.9 documentation
+ mlx.core.linalg.norm — MLX 0.1.0 documentation
@@ -47,7 +47,7 @@
-
+
@@ -135,8 +135,8 @@
-
-
+
+
@@ -242,6 +242,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -352,7 +354,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -434,6 +437,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -442,6 +446,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -453,14 +458,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -832,11 +849,11 @@ Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
next
-Neural Networks
+mlx.core.linalg.qr
diff --git a/docs/build/html/python/_autosummary/mlx.core.linalg.qr.html b/docs/build/html/python/_autosummary/mlx.core.linalg.qr.html
new file mode 100644
index 000000000..728b6c107
--- /dev/null
+++ b/docs/build/html/python/_autosummary/mlx.core.linalg.qr.html
@@ -0,0 +1,790 @@
+ mlx.core.linalg.qr — MLX 0.1.0 documentation
+mlx.core.linalg.qr
+
+
+mlx.core.linalg. qr ( a : array , * , stream : Union [ None , Stream , Device ] = None )
+The QR factorization of the input matrix.
+This function supports arrays with at least 2 dimensions. The matrices
+which are factorized are assumed to be in the last two dimensions of
+the input.
+
+Parameters:
+a (array ) – Input array.
+stream (Stream , optional ) – Stream or device. Defaults to None.
+Returns:
+The Q and R matrices.
+
+Return type:
+tuple (array , array )
+
+
+Example
+>>> A = mx . array ([[ 2. , 3. ], [ 1. , 2. ]])
+>>> Q , R = mx . linalg . qr ( A , stream = mx . cpu )
+>>> Q
+array([[-0.894427, -0.447214],
+ [-0.447214, 0.894427]], dtype=float32)
+>>> R
+array([[-2.23607, -3.57771],
+ [0, 0.447214]], dtype=float32)
+
\ No newline at end of file
diff --git a/docs/build/html/python/_autosummary/mlx.core.linspace.html b/docs/build/html/python/_autosummary/mlx.core.linspace.html
index 7e5ffc63d..cc8656913 100644
--- a/docs/build/html/python/_autosummary/mlx.core.linspace.html
+++ b/docs/build/html/python/_autosummary/mlx.core.linspace.html
@@ -9,7 +9,7 @@
- mlx.core.linspace — MLX 0.0.9 documentation
+ mlx.core.linspace — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -650,7 +667,7 @@ document.write(`
mlx.core.load
-mlx.core. load ( file : str , / , format : Optional [ str ] = None , * , stream : Union [ None , Stream , Device ] = None ) → Union [ array , Dict [ str , array ] ]
+mlx.core. load ( file : str , / , format : Optional [ str ] = None , return_metadata : bool = False , * , stream : Union [ None , Stream , Device ] = None ) → Union [ array , Dict [ str , array ] ]
Load array(s) from a binary file.
The supported formats are .npy, .npz, .safetensors, and .gguf.
@@ -660,11 +677,16 @@ document.write(`
format (str , optional ) – Format of the file. If None, the format
is inferred from the file extension. Supported formats: npy,
npz, and safetensors. Default: None.
+return_metadata (bool , optional ) – Load the metadata for formats which
+support metadata. The metadata will be returned as an additional
+dictionary.
Returns:
A single array if loading from a .npy file or a dict mapping
-names to arrays if loading from a .npz or .safetensors file.
+names to arrays if loading from a .npz or .safetensors file.
+If return_metadata is True, an additional dictionary of metadata
+will be returned.
Return type:
result (array , dict )
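A minimal usage sketch for the new return_metadata flag, assuming a GGUF file on disk (the file name is hypothetical):
>>> weights = mx.load("model.gguf")                              # arrays only
>>> weights, meta = mx.load("model.gguf", return_metadata=True)  # plus a metadata dict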
diff --git a/docs/build/html/python/_autosummary/mlx.core.log.html b/docs/build/html/python/_autosummary/mlx.core.log.html
index 5389c88e8..a33053a0e 100644
--- a/docs/build/html/python/_autosummary/mlx.core.log.html
+++ b/docs/build/html/python/_autosummary/mlx.core.log.html
@@ -9,7 +9,7 @@
- mlx.core.log — MLX 0.0.9 documentation
+ mlx.core.log — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -657,7 +674,7 @@ document.write(`
a (array ) – Input array.
shape (tuple ( int ) ) – New shape.
-stream (Stream , optional ) – Stream or device. Defaults to `None`
+stream (Stream , optional ) – Stream or device. Defaults to None
in which case the default stream of the default device is used.
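For reference, a minimal reshape call matching the corrected parameter description (values illustrative):
>>> a = mx.arange(6)
>>> mx.reshape(a, (2, 3))
array([[0, 1, 2],
       [3, 4, 5]], dtype=int32)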
diff --git a/docs/build/html/python/_autosummary/mlx.core.round.html b/docs/build/html/python/_autosummary/mlx.core.round.html
index d8180347b..65f231417 100644
--- a/docs/build/html/python/_autosummary/mlx.core.round.html
+++ b/docs/build/html/python/_autosummary/mlx.core.round.html
@@ -9,7 +9,7 @@
- mlx.core.round — MLX 0.0.9 documentation
+ mlx.core.round — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -650,7 +667,7 @@ document.write(`
mlx.core.save_gguf
-mlx.core. save_gguf ( file : str , arrays : Dict [ str , array ] )
+mlx.core. save_gguf ( file : str , arrays : Dict [ str , array ] , metadata : Dict [ str , Union [ array , str , List [ str ] ] ] )
Save array(s) to a binary file in .gguf format.
See the GGUF documentation for
more information on the format.
@@ -659,6 +676,9 @@ more information on the format.
file (file , str ) – File in which the array is saved.
arrays (dict ( str , array ) ) – The dictionary of names to arrays to be saved.
+metadata (dict ( str , Union [ array , str , list ( str ) ] ) ) – The dictionary of
+metadata to be saved. The values can be a scalar or 1D array,
+a str, or a list of str.
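A hedged sketch of the new metadata argument; the names and values here are illustrative:
>>> arrays = {"weight": mx.zeros((2, 2))}
>>> meta = {"model_name": "demo", "n_layers": mx.array(2)}  # str or scalar/1D array values
>>> mx.save_gguf("demo.gguf", arrays, meta)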
diff --git a/docs/build/html/python/_autosummary/mlx.core.save_safetensors.html b/docs/build/html/python/_autosummary/mlx.core.save_safetensors.html
index d0eea4427..31e059841 100644
--- a/docs/build/html/python/_autosummary/mlx.core.save_safetensors.html
+++ b/docs/build/html/python/_autosummary/mlx.core.save_safetensors.html
@@ -9,7 +9,7 @@
- mlx.core.save_safetensors — MLX 0.0.9 documentation
+ mlx.core.save_safetensors — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -657,7 +674,7 @@ document.write(`
a (array ) – Input array.
axis (int or tuple ( int ) , optional ) – Axes to remove.
-Defaults to `None` in which case all size one axes are removed.
+Defaults to None in which case all size one axes are removed.
Returns:
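A minimal sketch of the default behavior described above (shapes illustrative):
>>> a = mx.zeros((1, 3, 1))
>>> mx.squeeze(a)           # axis=None removes all size-one axes; result shape (3,)
>>> mx.squeeze(a, axis=0)   # removes only axis 0; result shape (3, 1)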
diff --git a/docs/build/html/python/_autosummary/mlx.core.stack.html b/docs/build/html/python/_autosummary/mlx.core.stack.html
index 128976fb0..9f2d85a50 100644
--- a/docs/build/html/python/_autosummary/mlx.core.stack.html
+++ b/docs/build/html/python/_autosummary/mlx.core.stack.html
@@ -9,7 +9,7 @@
- mlx.core.stack — MLX 0.0.9 documentation
+ mlx.core.stack — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -650,16 +667,16 @@ document.write(`
mlx.core.tensordot
-mlx.core. tensordot ( a : array , b : array , / , dims : Union [ int , List [ List [ int ] ] ] = 2 , * , stream : Union [ None , Stream , Device ] = None ) → array
+mlx.core. tensordot ( a : array , b : array , / , axes : Union [ int , List [ List [ int ] ] ] = 2 , * , stream : Union [ None , Stream , Device ] = None ) → array
Compute the tensor dot product along the specified axes.
Parameters:
a (array ) – Input array
b (array ) – Input array
-dims (int or list ( list ( int ) ) , optional ) – The number of dimensions to
+axes (int or list ( list ( int ) ) , optional ) – The number of dimensions to
sum over. If an integer is provided, then sum over the last
-dims dimensions of a and the first dims dimensions of
+axes dimensions of a and the first axes dimensions of
b. If a list of lists is provided, then sum over the
corresponding dimensions of a and b. (default: 2)
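A short sketch of the renamed axes parameter; the shapes are illustrative:
>>> a = mx.ones((3, 4, 5))
>>> b = mx.ones((4, 5, 2))
>>> mx.tensordot(a, b, axes=2)                 # contracts (4, 5); result shape (3, 2)
>>> mx.tensordot(a, b, axes=[[1, 2], [0, 1]])  # explicit per-array axes, same result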
diff --git a/docs/build/html/python/_autosummary/mlx.core.transpose.html b/docs/build/html/python/_autosummary/mlx.core.transpose.html
index 2cd6d8f6a..92a80a106 100644
--- a/docs/build/html/python/_autosummary/mlx.core.transpose.html
+++ b/docs/build/html/python/_autosummary/mlx.core.transpose.html
@@ -9,7 +9,7 @@
- mlx.core.transpose — MLX 0.0.9 documentation
+ mlx.core.transpose — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
diff --git a/docs/build/html/python/_autosummary/mlx.core.vmap.html b/docs/build/html/python/_autosummary/mlx.core.vmap.html
index 86e83ff1f..0c18e6904 100644
--- a/docs/build/html/python/_autosummary/mlx.core.vmap.html
+++ b/docs/build/html/python/_autosummary/mlx.core.vmap.html
@@ -9,7 +9,7 @@
- mlx.core.vmap — MLX 0.0.9 documentation
+ mlx.core.vmap — MLX 0.1.0 documentation
@@ -46,7 +46,7 @@
-
+
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -700,11 +717,11 @@ Defaults to 0
next
-mlx.core.simplify
+FFT
diff --git a/docs/build/html/python/_autosummary/mlx.core.where.html b/docs/build/html/python/_autosummary/mlx.core.where.html
index 81ebf3f17..b9f8745af 100644
--- a/docs/build/html/python/_autosummary/mlx.core.where.html
+++ b/docs/build/html/python/_autosummary/mlx.core.where.html
@@ -9,7 +9,7 @@
- mlx.core.where — MLX 0.0.9 documentation
+ mlx.core.where — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -650,7 +667,7 @@ document.write(`
mlx.nn.value_and_grad
-mlx.nn. value_and_grad ( model : Module , fn : Callable )
+mlx.nn. value_and_grad ( model : Module , fn : Callable )
Transform the passed function fn to a function that computes the
gradients of fn with respect to the model’s trainable parameters and also
its value.
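A minimal usage sketch, assuming a Module instance model, inputs x, and targets y (all hypothetical), with mse_loss as an example loss:
>>> def loss_fn(model, x, y):
...     return nn.losses.mse_loss(model(x), y)
...
>>> loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
>>> loss, grads = loss_and_grad_fn(model, x, y)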
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.AdaDelta.html b/docs/build/html/python/_autosummary/mlx.optimizers.AdaDelta.html
index 5f587c44e..bdd88ec76 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.AdaDelta.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.AdaDelta.html
@@ -9,7 +9,7 @@
- mlx.optimizers.AdaDelta — MLX 0.0.9 documentation
+ mlx.optimizers.AdaDelta — MLX 0.1.0 documentation
@@ -48,7 +48,7 @@
-
+
@@ -135,8 +135,8 @@
-
-
+
+
@@ -242,6 +242,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -352,7 +354,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -434,6 +437,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -442,6 +446,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -453,14 +458,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. AdaDelta ( learning_rate : float , rho : float = 0.9 , eps : float = 1e-06 )
-Implementation of the AdaDelta optimizer with learning rate[1].
+The AdaDelta optimizer with a learning rate [1].
Our AdaDelta implementation follows the original paper. In detail,
[1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
@@ -698,12 +715,12 @@ numerical stability. Default:
1e-8
previous
-mlx.optimizers.Adagrad
+mlx.optimizers.Adafactor
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.Adafactor.html b/docs/build/html/python/_autosummary/mlx.optimizers.Adafactor.html
new file mode 100644
+ mlx.optimizers.Adafactor — MLX 0.1.0 documentation
+mlx.optimizers.Adafactor
+
+
+class mlx.optimizers. Adafactor ( learning_rate : Optional [ float ] = None , eps : Tuple [ float , float ] = (1e-30, 0.001) , clip_threshold : float = 1.0 , decay_rate : float = - 0.8 , beta_1 : Optional [ float ] = None , weight_decay : float = 0.0 , scale_parameter : bool = True , relative_step : bool = True , warmup_init : bool = False )
+The Adafactor optimizer.
+Our Adafactor implementation follows the original paper: Adafactor:
+Adaptive Learning Rates with Sublinear Memory Cost
+
+Parameters:
+
+learning_rate (float , optional ) – The learning rate. Default: None.
+eps (tuple ( float , float ) , optional ) – The first term \(\epsilon_1\)
+is added to the square of the gradients to improve numerical
+stability and the second term \(\epsilon_2\) is used for
+parameter scaling if scale_parameter is set to True.
+Default: (1e-30, 1e-3).
+clip_threshold (float , optional ) – Clips the unscaled update at
+clip_threshold. Default: 1.0.
+decay_rate (float , optional ) – Coefficient for the running average
+of the squared gradient. Default: -0.8.
+beta_1 (float , optional ) – If set to a value bigger than zero,
+the first moment will be used. Default: None.
+weight_decay (float , optional ) – The weight decay \(\lambda\). Default: 0.0.
+scale_parameter (bool , optional ) – If set to True, the learning rate
+will be scaled by \(\max(\epsilon_1, \text{RMS}(w_{t-1}))\).
+Default: True.
+relative_step (bool , optional ) – If set to True, the learning_rate
+will be ignored and the relative step size will be computed.
+Default: True.
+warmup_init (bool , optional ) – If set to True, the relative
+step size will be calculated from the current step. Default: False.
+
+
+
+Methods
+
+__init__ ([learning_rate, eps, ...])
+
+apply_single (gradient, parameter, state)
+Performs the Adafactor parameter and state update.
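A minimal usage sketch under the defaults documented above; model and grads are assumed to exist:
>>> import mlx.optimizers as optim
>>> opt = optim.Adafactor()    # relative_step=True, so learning_rate may stay None
>>> opt.update(model, grads)   # applies the factored update to the model's parameters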
\ No newline at end of file
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.Adagrad.html b/docs/build/html/python/_autosummary/mlx.optimizers.Adagrad.html
index 812f257e9..5014fe6ea 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.Adagrad.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.Adagrad.html
@@ -9,7 +9,7 @@
- mlx.optimizers.Adagrad — MLX 0.0.9 documentation
+ mlx.optimizers.Adagrad — MLX 0.1.0 documentation
@@ -47,7 +47,7 @@
-
+
@@ -135,8 +135,8 @@
-
-
+
+
@@ -242,6 +242,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -352,7 +354,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -434,6 +437,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -442,6 +446,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -453,14 +458,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. Adagrad ( learning_rate : float , eps : float = 1e-08 )
-Implementation of the Adagrad optimizer [1].
+The Adagrad optimizer [1].
Our Adagrad implementation follows the original paper. In detail,
[1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
for online learning and stochastic optimization. JMLR 2011.
@@ -704,11 +721,11 @@ denominator to improve numerical stability. Default:
next
-mlx.optimizers.AdaDelta
+mlx.optimizers.Adafactor
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.Adam.html b/docs/build/html/python/_autosummary/mlx.optimizers.Adam.html
index e8b34a7fb..9e732c26f 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.Adam.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.Adam.html
@@ -9,7 +9,7 @@
- mlx.optimizers.Adam — MLX 0.0.9 documentation
+ mlx.optimizers.Adam — MLX 0.1.0 documentation
@@ -135,8 +135,8 @@
-
-
+
+
@@ -242,6 +242,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -352,7 +354,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -434,6 +437,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -442,6 +446,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -453,14 +458,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. Adam ( learning_rate : float , betas : List [ float ] = [0.9, 0.999] , eps : float = 1e-08 )
-Implementation of the Adam optimizer [1].
+The Adam optimizer [1].
Our Adam implementation follows the original paper and omits the bias
correction in the first and second moment estimates. In detail,
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.AdamW.html b/docs/build/html/python/_autosummary/mlx.optimizers.AdamW.html
index 871d81cde..d9c5ef21a 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.AdamW.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.AdamW.html
@@ -9,7 +9,7 @@
- mlx.optimizers.AdamW — MLX 0.0.9 documentation
+ mlx.optimizers.AdamW — MLX 0.1.0 documentation
@@ -135,8 +135,8 @@
-
-
+
+
@@ -242,6 +242,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -352,7 +354,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -434,6 +437,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -442,6 +446,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -453,14 +458,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. AdamW ( learning_rate : float , betas : List [ float ] = [0.9, 0.999] , eps : float = 1e-08 , weight_decay : float = 0.01 )
-Implementation of the AdamW optimizer [1].
+The AdamW optimizer [1].
Following the above convention, in contrast with [1], we do not use bias
correction in the first and second moments for AdamW. We update the weights
with a weight_decay (\(\lambda\) ) value:
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.Adamax.html b/docs/build/html/python/_autosummary/mlx.optimizers.Adamax.html
index 25df5c48b..f4e3af62d 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.Adamax.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.Adamax.html
@@ -9,7 +9,7 @@
- mlx.optimizers.Adamax — MLX 0.0.9 documentation
+ mlx.optimizers.Adamax — MLX 0.1.0 documentation
@@ -652,8 +669,7 @@ document.write(`
class mlx.optimizers. Adamax ( learning_rate : float , betas : List [ float ] = [0.9, 0.999] , eps : float = 1e-08 )
-Implementation of the Adamax optimizer. It is a variant of Adam based
-on the infinity norm [1].
+The Adamax optimizer, a variant of Adam based on the infinity norm [1].
Our Adam implementation follows the original paper and omits the bias
correction in the first and second moment estimates. In detail,
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.Lion.html b/docs/build/html/python/_autosummary/mlx.optimizers.Lion.html
index cb0fce4fc..43d764303 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.Lion.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.Lion.html
@@ -9,7 +9,7 @@
-
mlx.optimizers.Lion — MLX 0.0.9 documentation
+ mlx.optimizers.Lion — MLX 0.1.0 documentation
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. Lion ( learning_rate : float , betas : List [ float ] = [0.9, 0.99] , weight_decay : float = 0.0 )
-Implementation of the Lion optimizer [1].
+The Lion optimizer [1].
Since updates are computed through the sign operation, they tend to
have larger norm than for other optimizers such as SGD and Adam.
We recommend a learning rate that is 3-10x smaller than AdamW and a
@@ -662,9 +679,9 @@ detail,
[1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
preprint arXiv:2302.06675.
-\[c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t
-m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t
-w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)\]
+\[\begin{split}c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
+w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)\end{split}\]
Parameters:
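In line with the recommendation above, a minimal construction sketch (hyperparameter values are illustrative, not from the source):

import mlx.optimizers as optim

# Where AdamW might use learning_rate=1e-3, weight_decay=0.01, Lion
# prefers a 3-10x smaller learning rate and a larger weight decay.
optimizer = optim.Lion(learning_rate=1e-4, weight_decay=0.1)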
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.RMSprop.html b/docs/build/html/python/_autosummary/mlx.optimizers.RMSprop.html
--- a/docs/build/html/python/_autosummary/mlx.optimizers.RMSprop.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.RMSprop.html
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. RMSprop ( learning_rate : float , alpha : float = 0.99 , eps : float = 1e-08 )
-Implementation of the RMSprop optimizer [1].
+The RMSprop optimizer [1].
[1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
\[\begin{split}v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
diff --git a/docs/build/html/python/_autosummary/mlx.optimizers.SGD.html b/docs/build/html/python/_autosummary/mlx.optimizers.SGD.html
index 21de7d4d4..9e0291105 100644
--- a/docs/build/html/python/_autosummary/mlx.optimizers.SGD.html
+++ b/docs/build/html/python/_autosummary/mlx.optimizers.SGD.html
@@ -9,7 +9,7 @@
-
mlx.optimizers.SGD — MLX 0.0.9 documentation
+
mlx.optimizers.SGD — MLX 0.1.0 documentation
@@ -652,7 +669,7 @@ document.write(`
class mlx.optimizers. SGD ( learning_rate : float , momentum : float = 0.0 , weight_decay : float = 0.0 , dampening : float = 0.0 , nesterov : bool = False )
-Stochastic gradient descent optimizer.
+The stochastic gradient descent optimizer.
Updates a parameter \(w\) with a gradient \(g\) as follows
\[\begin{split}v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
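For completeness, a construction sketch for the update above (hyperparameters are illustrative):

import mlx.optimizers as optim

# momentum enables the velocity term, dampening scales the gradient
# contribution, and nesterov switches to the Nesterov momentum update.
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9, nesterov=True)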
diff --git a/docs/build/html/python/_autosummary/mlx.utils.tree_flatten.html b/docs/build/html/python/_autosummary/mlx.utils.tree_flatten.html
index f506bf1b9..ec3376a5b 100644
--- a/docs/build/html/python/_autosummary/mlx.utils.tree_flatten.html
+++ b/docs/build/html/python/_autosummary/mlx.utils.tree_flatten.html
@@ -9,7 +9,7 @@
-
mlx.utils.tree_flatten — MLX 0.0.9 documentation
+
mlx.utils.tree_flatten — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/fft.html b/docs/build/html/python/fft.html
index 9d465bf50..3a250eb6b 100644
--- a/docs/build/html/python/fft.html
+++ b/docs/build/html/python/fft.html
@@ -9,7 +9,7 @@
- FFT — MLX 0.0.9 documentation
+ FFT — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn.html b/docs/build/html/python/nn.html
index 9108d2de2..3e8eaf664 100644
--- a/docs/build/html/python/nn.html
+++ b/docs/build/html/python/nn.html
@@ -9,7 +9,7 @@
- Neural Networks — MLX 0.0.9 documentation
+ Neural Networks — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html
@@ -651,12 +668,12 @@ document.write(`
Module. load_weights ( file_or_weights : Union [ str , List [ Tuple [ str , array ] ] ] , strict : bool = True )
-Update the model’s weights from a .npz or a list.
+Update the model’s weights from a .npz, a .safetensors file, or a list.
Parameters:
file_or_weights (str or list(tuple(str, mx.array))) – The path to
-the weights .npz file or a list of pairs of parameter names
+the weights .npz file (.npz or .safetensors) or a list of pairs of parameter names
and arrays.
strict (bool, optional) – If True then checks that the provided
weights exactly match the parameters of the model. Otherwise,
@@ -673,6 +690,9 @@ shapes are not checked. Default:
# Load from file
model.load_weights("weights.npz")
+# Load from .safetensors file
+model.load_weights("weights.safetensors")
+
# Load from list
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.modules.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.modules.html
index 53471e36e..fb2b2153d 100644
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.modules.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.modules.html
@@ -9,7 +9,7 @@
- mlx.nn.Module.modules — MLX 0.0.9 documentation
+ mlx.nn.Module.modules — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.save_weights.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.save_weights.html
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.save_weights.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.save_weights.html
@@ -651,7 +668,9 @@ document.write(`
Module. save_weights ( file : str )
-Save the model’s weights to a .npz file.
+Save the model’s weights to a file. The saving method is determined by the file extension:
+- .npz will use mx.savez()
+- .safetensors will use mx.save_safetensors()
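A short round-trip sketch of the behavior described above (file names are placeholders):

# The extension selects the serialization backend
model.save_weights("model.npz")          # uses mx.savez()
model.save_weights("model.safetensors")  # uses mx.save_safetensors()

# load_weights accepts either format
model.load_weights("model.safetensors")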
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.train.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.train.html
index 9db0ad371..98718945a 100644
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.Module.train.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Module.train.html
@@ -9,7 +9,7 @@
- mlx.nn.Module.train — MLX 0.0.9 documentation
+ mlx.nn.Module.train — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.PReLU.html b/docs/build/html/python/nn/_autosummary/mlx.nn.PReLU.html
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.PReLU.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.PReLU.html
@@ -657,12 +674,12 @@ document.write(`
is an array.
-See prelu()
, for the functional equivalent.
+See prelu()
for the functional equivalent.
Parameters:
-num_parameters – number of \(a\) to learn. Default: 1
-init – the initial value of \(a\) . Default: 0.25
+num_parameters – number of \(a\) to learn. Default: 1
+init – the initial value of \(a\) . Default: 0.25
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.QuantizedLinear.html b/docs/build/html/python/nn/_autosummary/mlx.nn.QuantizedLinear.html
index 9470d2eae..9167ae938 100644
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.QuantizedLinear.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.QuantizedLinear.html
@@ -9,7 +9,7 @@
- mlx.nn.QuantizedLinear — MLX 0.0.9 documentation
+ mlx.nn.QuantizedLinear — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html b/docs/build/html/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html
index 812844f5f..a07d38e0a 100644
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.html
@@ -9,7 +9,7 @@
- mlx.nn.SinusoidalPositionalEncoding — MLX 0.0.9 documentation
+ mlx.nn.SinusoidalPositionalEncoding — MLX 0.1.0 documentation
@@ -696,11 +713,11 @@ instead of the reverse. Default:
next
-
mlx.nn.Step
+
mlx.nn.Softshrink
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Softshrink.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Softshrink.html
new file mode 100644
index 000000000..3c25fe6aa
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Softshrink.html
@@ -0,0 +1,768 @@
+ mlx.nn.Softshrink — MLX 0.1.0 documentation
+mlx.nn.Softshrink
+
+
+class mlx.nn. Softshrink ( lambd = 0.5 )
+Applies the Softshrink function.
+See softshrink()
for the functional equivalent.
+
+Parameters:
+lambd – the \(\lambda\) value for Softshrink. Default: 0.5
+
+
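As a sketch of what the layer computes (the standard softshrink definition; the input array is illustrative):

import mlx.core as mx
import mlx.nn as nn

# softshrink(x) = x - lambd if x > lambd
#                 x + lambd if x < -lambd
#                 0         otherwise
layer = nn.Softshrink(lambd=0.5)
y = layer(mx.array([-1.0, -0.25, 0.25, 1.0]))  # [-0.5, 0.0, 0.0, 0.5]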
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Step.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Step.html
index 0d60418a5..c5f04474a 100644
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.Step.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Step.html
@@ -9,7 +9,7 @@
- mlx.nn.Step — MLX 0.0.9 documentation
+ mlx.nn.Step — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.Transformer.html b/docs/build/html/python/nn/_autosummary/mlx.nn.Transformer.html
--- a/docs/build/html/python/nn/_autosummary/mlx.nn.Transformer.html
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.Transformer.html
@@ -650,7 +667,7 @@ document.write(`
mlx.nn.Transformer
-class mlx.nn. Transformer ( dims: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: ~typing.Optional[int] = None, dropout: float = 0.0, activation: ~typing.Callable[[~typing.Any], ~typing.Any] = <function relu>, custom_encoder: ~typing.Optional[~typing.Any] = None, custom_decoder: ~typing.Optional[~typing.Any] = None, norm_first: bool = False )
+class mlx.nn. Transformer ( dims: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: ~typing.Optional[int] = None, dropout: float = 0.0, activation: ~typing.Callable[[~typing.Any], ~typing.Any] = <function relu>, custom_encoder: ~typing.Optional[~typing.Any] = None, custom_decoder: ~typing.Optional[~typing.Any] = None, norm_first: bool = True, checkpoint: bool = False )
Implements a standard Transformer model.
The implementation is based on Attention Is All You Need .
The Transformer model contains an encoder and a decoder. The encoder
@@ -682,7 +699,10 @@ standard Transformer encoder. Default: None
.
norm_first (bool , optional ) – if True
, encoder and decoder layers
will perform layer normalization before attention and MLP
-operations, otherwise after. Default: False
.
+operations, otherwise after. Default: True
.
+checkpoint (bool, optional) – if True, perform gradient checkpointing
+to reduce the memory usage at the expense of more computation.
+Default: False.
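A construction sketch reflecting both changes in this hunk, the new norm_first default and the checkpoint flag (sizes are the documented defaults):

import mlx.nn as nn

# norm_first now defaults to True (pre-norm layers); checkpoint=True
# trades extra computation for lower memory use during training.
model = nn.Transformer(dims=512, num_heads=8, checkpoint=True)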
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html
new file mode 100644
index 000000000..745d1e5ef
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.constant.html
@@ -0,0 +1,784 @@
+ mlx.nn.init.constant — MLX 0.1.0 documentation
+mlx.nn.init.constant
+
+
+mlx.nn.init. constant ( value : float , dtype : Dtype = mlx.core.float32 ) → Callable [ [ array ] , array ]
+An initializer that returns an array filled with value.
+
+Parameters:
+
+value (float) – The value to fill the array with.
+dtype (Dtype, optional) – The data type of the array. Default: float32.
+
+
+Returns:
+An initializer that returns an array with the
+same shape as the input, filled with value.
+
+Return type:
+Callable[[array], array]
+
+
+Example
+>>> init_fn = nn . init . constant ( 0.5 )
+>>> init_fn ( mx . zeros (( 2 , 2 )))
+array([[0.5, 0.5],
+ [0.5, 0.5]], dtype=float32)
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.glorot_normal.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.glorot_normal.html
new file mode 100644
index 000000000..e84304f67
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.glorot_normal.html
@@ -0,0 +1,792 @@
+ mlx.nn.init.glorot_normal — MLX 0.1.0 documentation
+mlx.nn.init.glorot_normal
+
+
+mlx.nn.init. glorot_normal ( dtype : Dtype = mlx.core.float32 ) → Callable [ [ array , float ] , array ]
+A Glorot normal initializer.
+This initializer samples from a normal distribution with a standard
+deviation computed from the number of input (fan_in) and output
+(fan_out) units according to:
+
+\[\sigma = \gamma \sqrt{\frac{2.0}{\text{fan_in} + \text{fan_out}}}\]
+For more details see the original reference: Understanding the difficulty
+of training deep feedforward neural networks
+
+Parameters:
+dtype (Dtype, optional) – The data type of the array. Default: float32.
+
+Returns:
+An initializer that returns an array
+with the same shape as the input, filled with samples from the Glorot
+normal distribution.
+
+Return type:
+Callable[[array, float], array]
+
+
+Example
+>>> init_fn = nn . init . glorot_normal ()
+>>> init_fn ( mx . zeros (( 2 , 2 )))
+array([[0.191107, 1.61278],
+ [-0.150594, -0.363207]], dtype=float32)
+>>> init_fn ( mx . zeros (( 2 , 2 )), gain = 4.0 )
+array([[1.89613, -4.53947],
+ [4.48095, 0.995016]], dtype=float32)
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.glorot_uniform.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.glorot_uniform.html
new file mode 100644
index 000000000..f5635fe4b
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.glorot_uniform.html
@@ -0,0 +1,792 @@
+ mlx.nn.init.glorot_uniform — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.he_normal.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.he_normal.html
new file mode 100644
index 000000000..eced872ee
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.he_normal.html
@@ -0,0 +1,795 @@
+ mlx.nn.init.he_normal — MLX 0.1.0 documentation
+mlx.nn.init.he_normal
+
+
+mlx.nn.init. he_normal ( dtype : Dtype = mlx.core.float32 ) → Callable [ [ array , str , float ] , array ]
+Build a He normal initializer.
+This initializer samples from a normal distribution with a standard
+deviation computed from the number of input (fan_in) or output
+(fan_out) units according to:
+
+\[\sigma = \gamma \frac{1}{\sqrt{\text{fan}}}\]
+where \(\text{fan}\) is either the number of input units when the
+mode is "fan_in" or output units when the mode is "fan_out".
+For more details see the original reference: Delving Deep into Rectifiers:
+Surpassing Human-Level Performance on ImageNet Classification
+
+Parameters:
+dtype (Dtype, optional) – The data type of the array. Default: float32.
+
+Returns:
+An initializer that returns an
+array with the same shape as the input, filled with samples from the He
+normal distribution.
+
+Return type:
+Callable[[array, str, float], array]
+
+
+Example
+>>> init_fn = nn . init . he_normal ()
+>>> init_fn ( mx . zeros (( 2 , 2 ))) # uses fan_in
+array([[-1.25211, 0.458835],
+ [-0.177208, -0.0137595]], dtype=float32)
+>>> init_fn ( mx . zeros (( 2 , 2 )), mode = "fan_out" , gain = 5 )
+array([[5.6967, 4.02765],
+ [-4.15268, -2.75787]], dtype=float32)
+
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.he_uniform.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.he_uniform.html
new file mode 100644
index 000000000..f11d0d6e1
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.he_uniform.html
@@ -0,0 +1,795 @@
+ mlx.nn.init.he_uniform — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.identity.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.identity.html
new file mode 100644
index 000000000..685cad050
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.identity.html
@@ -0,0 +1,781 @@
+ mlx.nn.init.identity — MLX 0.1.0 documentation
+mlx.nn.init.identity
+
+
+mlx.nn.init. identity ( dtype : Dtype = mlx.core.float32 ) → Callable [ [ array ] , array ]
+An initializer that returns an identity matrix.
+
+Parameters:
+dtype (Dtype, optional) – The data type of the array. Default: float32.
+
+Returns:
+An initializer that returns an identity
+matrix with the same shape as the input.
+
+Return type:
+Callable[[array], array]
+
+
+Example
+>>> init_fn = nn . init . identity ()
+>>> init_fn ( mx . zeros (( 2 , 2 )))
+array([[1, 0],
+ [0, 1]], dtype=float32)
+
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.normal.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.normal.html
new file mode 100644
index 000000000..3a31487d0
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.normal.html
@@ -0,0 +1,787 @@
+ mlx.nn.init.normal — MLX 0.1.0 documentation
+mlx.nn.init.normal
+
+
+mlx.nn.init. normal ( mean : float = 0.0 , std : float = 1.0 , dtype : Dtype = mlx.core.float32 ) → Callable [ [ array ] , array ]
+An initializer that returns samples from a normal distribution.
+
+Parameters:
+
+mean (float, optional) – Mean of the normal distribution. Default: 0.0.
+std (float, optional) – Standard deviation of the normal distribution.
+Default: 1.0.
+dtype (Dtype, optional) – The data type of the array. Default: float32.
+
+
+Returns:
+An initializer that returns an array with the
+same shape as the input, filled with samples from a normal distribution.
+
+Return type:
+Callable[[array], array]
+
+
+Example
+>>> init_fn = nn . init . normal ()
+>>> init_fn ( mx . zeros (( 2 , 2 )))
+array([[-0.982273, -0.534422],
+ [0.380709, 0.0645099]], dtype=float32)
+
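Each initializer returns a function mapping an array to a same-shaped array, so one way to re-initialize an entire module is to map it over the parameter tree; a sketch assuming mlx.utils.tree_map and Module.update, with model as a placeholder:

import mlx.nn as nn
from mlx.utils import tree_map

init_fn = nn.init.normal(mean=0.0, std=0.02)

# Replace every parameter with a freshly sampled array of the same shape
model.update(tree_map(lambda p: init_fn(p), model.parameters()))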
diff --git a/docs/build/html/python/nn/_autosummary/mlx.nn.init.uniform.html b/docs/build/html/python/nn/_autosummary/mlx.nn.init.uniform.html
new file mode 100644
index 000000000..26535b35a
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary/mlx.nn.init.uniform.html
@@ -0,0 +1,787 @@
+ mlx.nn.init.uniform — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu.html
index 842999ec4..d25abbc47 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.gelu.html
@@ -9,7 +9,7 @@
- mlx.nn.gelu — MLX 0.0.9 documentation
+ mlx.nn.gelu — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.constant.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.constant.html
new file mode 100644
index 000000000..5b9891b48
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.constant.html
@@ -0,0 +1,748 @@
+ mlx.nn.init.constant — MLX documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.html
new file mode 100644
index 000000000..7887e4e89
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_normal.html
@@ -0,0 +1,757 @@
+ mlx.nn.init.glorot_normal — MLX documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.html
new file mode 100644
index 000000000..7b62eaa2e
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.glorot_uniform.html
@@ -0,0 +1,757 @@
+ mlx.nn.init.glorot_uniform — MLX documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_normal.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_normal.html
new file mode 100644
index 000000000..89469ead0
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_normal.html
@@ -0,0 +1,761 @@
+ mlx.nn.init.he_normal — MLX documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.html
new file mode 100644
index 000000000..e5fd098cf
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.he_uniform.html
@@ -0,0 +1,761 @@
+ mlx.nn.init.he_uniform — MLX documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.identity.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.identity.html
new file mode 100644
index 000000000..ae669ae58
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.identity.html
@@ -0,0 +1,746 @@
+
+
+
+
+
+
+
+
+
+
+
+ mlx.nn.init.identity — MLX documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Skip to main content
+
+
+
+
+
+
+ Back to top
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
mlx.nn.init.identity
+
+
+
+
+
+
+
+
+
+
+mlx.nn.init.identity
+
+
+mlx.nn.init. identity ( dtype : Dtype = mlx.core.float32 ) → Callable [ [ array ] , array ]
+An initializer that returns an identity matrix.
+
+Parameters:
+dtype (Dtype , optional ) – The data type of the array. Defaults:
+float32
.
+
+Returns:
+An initializer that returns an identity
+matrix with the same shape as the input.
+
+Return type:
+Callable[[array], array]
+
+
+Example
+>>> init_fn = nn.init.identity()
+>>> init_fn(mx.zeros((2, 2)))
+array([[1, 0],
+       [0, 1]], dtype=float32)
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.normal.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.normal.html
new file mode 100644
index 000000000..8cb8633f0
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.normal.html
@@ -0,0 +1,752 @@
+ mlx.nn.init.normal — MLX documentation
+mlx.nn.init.normal
+
+
+mlx.nn.init.normal(mean: float = 0.0, std: float = 1.0, dtype: Dtype = mlx.core.float32) → Callable[[array], array]
+An initializer that returns samples from a normal distribution.
+
+Parameters:
+
+mean (float, optional) – Mean of the normal distribution. Default: 0.0.
+std (float, optional) – Standard deviation of the normal distribution. Default: 1.0.
+dtype (Dtype, optional) – The data type of the array. Default: float32.
+
+
+Returns:
+An initializer that returns an array with the
+same shape as the input, filled with samples from a normal distribution.
+
+Return type:
+Callable[[array], array]
+
+
+Example
+>>> init_fn = nn.init.normal()
+>>> init_fn(mx.zeros((2, 2)))
+array([[-0.982273, -0.534422],
+       [0.380709, 0.0645099]], dtype=float32)
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.uniform.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.uniform.html
new file mode 100644
index 000000000..b8c85f7c2
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.init.uniform.html
@@ -0,0 +1,752 @@
+ mlx.nn.init.uniform — MLX documentation
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.constant.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.constant.html
new file mode 100644
index 000000000..6eb6b0a94
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.constant.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.constant — MLX documentation
+mlx.nn.initializers.constant
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.html
new file mode 100644
index 000000000..cee367694
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_normal.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.glorot_normal — MLX documentation
+mlx.nn.initializers.glorot_normal
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.html
new file mode 100644
index 000000000..00f6fb9d9
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.glorot_uniform.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.glorot_uniform — MLX documentation
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.html
new file mode 100644
index 000000000..fb55c661e
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_normal.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.he_normal — MLX documentation
+mlx.nn.initializers.he_normal
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.html
new file mode 100644
index 000000000..950207657
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.he_uniform.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.he_uniform — MLX documentation
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.identity.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.identity.html
new file mode 100644
index 000000000..09399bb5a
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.identity.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.identity — MLX documentation
+mlx.nn.initializers.identity
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.normal.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.normal.html
new file mode 100644
index 000000000..e62ee5239
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.normal.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.normal — MLX documentation
+mlx.nn.initializers.normal
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.html
new file mode 100644
index 000000000..997779fb3
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.initializers.uniform.html
@@ -0,0 +1,720 @@
+ mlx.nn.initializers.uniform — MLX documentation
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html
index e08ad3bcd..5731d3276 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.binary_cross_entropy — MLX 0.0.9 documentation
+ mlx.nn.losses.binary_cross_entropy — MLX 0.1.0 documentation
@@ -650,15 +667,17 @@ document.write(`
mlx.nn.losses.binary_cross_entropy
-class mlx.nn.losses.binary_cross_entropy(logits: array, targets: array, reduction: str = 'none')
+class mlx.nn.losses.binary_cross_entropy(inputs: array, targets: array, with_logits: bool = True, reduction: Literal['none', 'mean', 'sum'] = 'mean')
Computes the binary cross entropy loss.
Parameters:
-logits (array) – The unnormalized (pre-sigmoid) predicted logits.
+inputs (array) – The predicted values. If with_logits is True, then
+inputs are unnormalized logits. Otherwise, inputs are probabilities.
 targets (array) – The binary target values in {0, 1}.
+with_logits (bool, optional) – Whether inputs are logits. Default: True.
 reduction (str, optional) – Specifies the reduction to apply to the output:
-'none' | 'mean' | 'sum'. Default: 'none'.
+'none' | 'mean' | 'sum'. Default: 'mean'.
Returns:
@@ -671,11 +690,20 @@ document.write(`
Examples
>>> import mlx.core as mx
>>> import mlx.nn as nn
->>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
+>>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
 >>> targets = mx.array([0, 0, 1, 1])
->>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
+>>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
 >>> loss
-array([0.612192], dtype=float32)
+array(0.539245, dtype=float32)
+
+>>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
+>>> targets = mx.array([0, 0, 1, 1])
+>>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
+>>> loss
+array(0.510826, dtype=float32)
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html
index 4d03d99aa..8255dbc30 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.cosine_similarity_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.cosine_similarity_loss — MLX 0.1.0 documentation
@@ -651,7 +668,7 @@ document.write(`
mlx.nn.losses.cosine_similarity_loss
-class mlx.nn.losses.cosine_similarity_loss(x1: array, x2: array, axis: int = 1, eps: float = 1e-08, reduction: str = 'none')
+class mlx.nn.losses.cosine_similarity_loss(x1: array, x2: array, axis: int = 1, eps: float = 1e-08, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the cosine similarity between the two inputs.
The cosine similarity loss is given by
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html
index fa3193bc3..bf587fd4a 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.cross_entropy — MLX 0.0.9 documentation
+ mlx.nn.losses.cross_entropy — MLX 0.1.0 documentation
@@ -650,14 +667,19 @@ document.write(`
mlx.nn.losses.cross_entropy
-class mlx.nn.losses.cross_entropy(logits: array, targets: array, weights: Optional[array] = None, axis: int = -1, label_smoothing: float = 0.0, reduction: str = 'none')
+class mlx.nn.losses.cross_entropy(logits: array, targets: array, weights: Optional[array] = None, axis: int = -1, label_smoothing: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the cross entropy loss.
Parameters:
-logits (array) – The unnormalized predicted logits.
-targets (array) – The target values, as class indices.
-weights (array, optional) – Weights for each target. Default: None.
+logits (array) – The unnormalized logits.
+targets (array) – The ground truth values. These can be class indices or
+probabilities for each class. If the targets are class indices, then the
+targets shape should match the logits shape with the axis dimension
+removed. If the targets are probabilities (or one-hot encoded), then the
+targets shape should be the same as the logits shape.
+weights (array, optional) – Optional weights for each target. Default: None.
 axis (int, optional) – The axis over which to compute softmax. Default: -1.
 label_smoothing (float, optional) – Label smoothing factor. Default: 0.
 reduction (str, optional) – Specifies the reduction to apply to the output:
@@ -671,6 +693,23 @@ document.write(`
array
+Examples
+>>> import mlx.core as mx
+>>> import mlx.nn as nn
+>>>
+>>> # Class indices as targets
+>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+>>> targets = mx.array([0, 1])
+>>> nn.losses.cross_entropy(logits, targets)
+array([0.0485873, 0.0485873], dtype=float32)
+>>>
+>>> # Probabilities (or one-hot vectors) as targets
+>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
+>>> nn.losses.cross_entropy(logits, targets)
+array([0.348587, 0.348587], dtype=float32)
+
+
@@ -696,11 +735,11 @@ document.write(`
next
- mlx.nn.losses.hinge_loss
+ mlx.nn.losses.gaussian_nll_loss
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html
new file mode 100644
index 000000000..56a1691d4
--- /dev/null
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.html
@@ -0,0 +1,790 @@
+ mlx.nn.losses.gaussian_nll_loss — MLX 0.1.0 documentation
+mlx.nn.losses.gaussian_nll_loss
+
+
+class mlx.nn.losses.gaussian_nll_loss(inputs: array, targets: array, vars: array, full: bool = False, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'mean')
+Computes the negative log likelihood loss for a Gaussian distribution.
+The loss is given by:
+
+\[\frac{1}{2}\left(\log\left(\max\left(\text{vars},
+\ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2}
+{\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}\]
+where inputs are the predicted means and vars are the
+predicted variances.
+
+Parameters:
+
+inputs (array) – The predicted expectation of the Gaussian distribution.
+targets (array) – The target values (samples from the Gaussian distribution).
+vars (array) – The predicted variance of the Gaussian distribution.
+full (bool, optional) – Whether to include the constant term in the loss calculation. Default: False.
+eps (float, optional) – Small positive constant for numerical stability. Default: 1e-6.
+reduction (str, optional) – Specifies the reduction to apply to the output:
+'none' | 'mean' | 'sum'. Default: 'mean'.
+
+
+Returns:
+The Gaussian NLL loss.
+
+Return type:
+array
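+
+As a sketch of how the formula above plays out (not part of the original
+page; the values are illustrative), the unreduced loss can be reproduced
+by hand:
+import mlx.core as mx
+import mlx.nn as nn
+
+inputs = mx.array([0.0, 1.0])    # predicted means
+targets = mx.array([0.1, 0.9])   # observed samples
+vars = mx.array([0.5, 0.5])      # predicted variances
+
+# The formula with full=False and eps=1e-6:
+v = mx.maximum(vars, 1e-6)
+manual = 0.5 * (mx.log(v) + (inputs - targets) ** 2 / v)
+loss = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="none")
+# manual and loss should match elementwise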
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html
index c16c8b53c..a5caf3bda 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.hinge_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.hinge_loss — MLX 0.1.0 documentation
@@ -651,7 +668,7 @@ document.write(`
mlx.nn.losses.hinge_loss
-class mlx.nn.losses.hinge_loss(inputs: array, targets: array, reduction: str = 'none')
+class mlx.nn.losses.hinge_loss(inputs: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the hinge loss between inputs and targets.
\[\text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})\]
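+
+A small worked example (a sketch, not part of the original page): with
+labels in {-1, 1}, confident correct predictions incur zero loss.
+import mlx.core as mx
+import mlx.nn as nn
+
+inputs = mx.array([2.0, 0.5, -1.0])   # predictions y_pred
+targets = mx.array([1.0, 1.0, 1.0])   # labels y
+nn.losses.hinge_loss(inputs, targets)
+# max(0, 1 - y * y_pred) elementwise: [0, 0.5, 2]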
@@ -651,10 +668,10 @@ document.write(`
mlx.nn.losses.huber_loss
-class mlx.nn.losses.huber_loss(inputs: array, targets: array, delta: float = 1.0, reduction: str = 'none')
+class mlx.nn.losses.huber_loss(inputs: array, targets: array, delta: float = 1.0, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the Huber loss between inputs and targets.
-\[\begin{split}L_{\delta}(a) =
+\[\begin{split}l_{\delta}(a) =
\left\{ \begin{array}{ll}
\frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
\delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
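+
+A short illustrative sketch (not part of the original page): with the
+default delta = 1.0, a small residual is penalized quadratically and a
+large one linearly.
+import mlx.core as mx
+import mlx.nn as nn
+
+inputs = mx.array([0.0, 0.0])
+targets = mx.array([0.5, 3.0])
+nn.losses.huber_loss(inputs, targets, reduction="none")
+# first element: 0.5 * 0.5**2 = 0.125; second: 3.0 - 0.5 * 1.0 = 2.5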
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html
index 084980d41..3b2ed786a 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.kl_div_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.kl_div_loss — MLX 0.1.0 documentation
@@ -650,7 +667,7 @@ document.write(`
mlx.nn.losses.kl_div_loss
-class mlx.nn.losses.kl_div_loss(inputs: array, targets: array, axis: int = -1, reduction: str = 'none')
+class mlx.nn.losses.kl_div_loss(inputs: array, targets: array, axis: int = -1, reduction: Literal['none', 'mean', 'sum'] = 'none')
 Computes the Kullback-Leibler divergence loss.
 Computes the following when reduction == 'none':
 mx.exp(targets) * (targets - inputs).sum(axis)
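+
+A minimal sketch (not part of the original page) checking the expression
+above against the function, with both arguments given as log-probabilities:
+import mlx.core as mx
+import mlx.nn as nn
+
+inputs = mx.log(mx.array([0.25, 0.75]))   # log q
+targets = mx.log(mx.array([0.5, 0.5]))    # log p
+manual = (mx.exp(targets) * (targets - inputs)).sum(-1)
+loss = nn.losses.kl_div_loss(inputs, targets)
+# loss should equal manual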
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html
index 545ec1284..797677444 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.l1_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.l1_loss — MLX 0.1.0 documentation
@@ -650,7 +667,7 @@ document.write(`
mlx.nn.losses.l1_loss
-class mlx.nn.losses.l1_loss(predictions: array, targets: array, reduction: str = 'mean')
+class mlx.nn.losses.l1_loss(predictions: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'mean')
Computes the L1 loss.
Parameters:
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html
index ec099ea09..98fee95e9 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.log_cosh_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.log_cosh_loss — MLX 0.1.0 documentation
@@ -651,7 +668,7 @@ document.write(`
mlx.nn.losses.log_cosh_loss
-class mlx.nn.losses.log_cosh_loss(inputs: array, targets: array, reduction: str = 'none')
+class mlx.nn.losses.log_cosh_loss(inputs: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients,
and like the L1 loss for large errors, reducing sensitivity to outliers. This
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html
index 5fae50fee..cb30814cd 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.mse_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.mse_loss — MLX 0.1.0 documentation
@@ -650,7 +667,7 @@ document.write(`
mlx.nn.losses.mse_loss
-class mlx.nn.losses.mse_loss(predictions: array, targets: array, reduction: str = 'mean')
+class mlx.nn.losses.mse_loss(predictions: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'mean')
Computes the mean squared error loss.
Parameters:
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html
index 2cd90a9fd..0bd0dc35e 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.nll_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.nll_loss — MLX 0.1.0 documentation
@@ -650,7 +667,7 @@ document.write(`
mlx.nn.losses.nll_loss
-class mlx.nn.losses.nll_loss(inputs: array, targets: array, axis: int = -1, reduction: str = 'none')
+class mlx.nn.losses.nll_loss(inputs: array, targets: array, axis: int = -1, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the negative log likelihood loss.
Parameters:
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html
index 490ae81c5..322943fb3 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.smooth_l1_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.smooth_l1_loss — MLX 0.1.0 documentation
@@ -651,7 +668,7 @@ document.write(`
mlx.nn.losses.smooth_l1_loss
-class mlx.nn.losses.smooth_l1_loss(predictions: array, targets: array, beta: float = 1.0, reduction: str = 'mean')
+class mlx.nn.losses.smooth_l1_loss(predictions: array, targets: array, beta: float = 1.0, reduction: Literal['none', 'mean', 'sum'] = 'mean')
Computes the smooth L1 loss.
The smooth L1 loss is a variant of the L1 loss which replaces the absolute
difference with a squared difference when the absolute difference is less
@@ -659,10 +676,10 @@ than beta The formula for the smooth L1 Loss is:
\[\begin{split}l =
- \begin{cases}
- 0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
- |x - y| - 0.5 \beta, & & \text{otherwise}
- \end{cases}\end{split}\]
+ \begin{cases}
+ 0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
+ |x - y| - 0.5 \beta, & & \text{otherwise}
+ \end{cases}\end{split}\]
Parameters:
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html
index 6e3ec976c..6fe0d14fc 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.html
@@ -9,7 +9,7 @@
- mlx.nn.losses.triplet_loss — MLX 0.0.9 documentation
+ mlx.nn.losses.triplet_loss — MLX 0.1.0 documentation
@@ -651,11 +668,11 @@ document.write(`
mlx.nn.losses.triplet_loss
-class mlx.nn.losses.triplet_loss(anchors: array, positives: array, negatives: array, axis: int = -1, p: int = 2, margin: float = 1.0, eps: float = 1e-06, reduction: str = 'none')
+class mlx.nn.losses.triplet_loss(anchors: array, positives: array, negatives: array, axis: int = -1, p: int = 2, margin: float = 1.0, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the triplet loss for a set of anchor, positive, and negative samples.
Margin is represented with alpha in the math section.
-\[L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\]
+\[\max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\]
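+
+A tiny worked sketch (not part of the original page): anchors close to
+positives and far from negatives give zero loss once the margin alpha is
+cleared.
+import mlx.core as mx
+import mlx.nn as nn
+
+anchors = mx.array([[0.0, 0.0]])
+positives = mx.array([[0.1, 0.0]])
+negatives = mx.array([[3.0, 0.0]])
+nn.losses.triplet_loss(anchors, positives, negatives)
+# max(0.1 - 3.0 + 1.0, 0) = 0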
Parameters:
@@ -706,11 +723,11 @@ Margin is represented with alpha in the math section.
next
-
Optimizers
+
Initializers
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.mish.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.mish.html
index 84720b763..831a1e049 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.mish.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.mish.html
@@ -9,7 +9,7 @@
- mlx.nn.mish — MLX 0.0.9 documentation
+ mlx.nn.mish — MLX 0.1.0 documentation
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.selu.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.selu.html
index aef28d7ad..89da75b72 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.selu.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.selu.html
@@ -9,7 +9,7 @@
- mlx.nn.selu — MLX 0.0.9 documentation
+ mlx.nn.selu — MLX 0.1.0 documentation
@@ -685,11 +702,11 @@ document.write(`
next
- mlx.nn.silu
+ mlx.nn.softshrink
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.silu.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.silu.html
index 6fc1def8c..49138e87b 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.silu.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.silu.html
@@ -9,7 +9,7 @@
- mlx.nn.silu — MLX 0.0.9 documentation
+ mlx.nn.silu — MLX 0.1.0 documentation
@@ -671,12 +688,12 @@ the logistic sigmoid.
previous
- mlx.nn.selu
+ mlx.nn.softshrink
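diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.softshrink.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.softshrink.html
new file mode 100644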
+ mlx.nn.softshrink — MLX 0.1.0 documentation
+mlx.nn.softshrink
+
+
+class mlx.nn.softshrink(x, lambd: float = 0.5)
+Applies the Softshrink activation function.
+
+\[\begin{split}\text{softshrink}(x) = \begin{cases}
+x - \lambda & \text{if } x > \lambda \\
+x + \lambda & \text{if } x < -\lambda \\
+0 & \text{otherwise}
+\end{cases}\end{split}\]
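+
+A reference sketch of the piecewise definition above (not part of the
+original page), written with mx.where:
+import mlx.core as mx
+
+def softshrink_ref(x, lambd=0.5):
+    # x - lambd above lambd, x + lambd below -lambd, zero in between
+    return mx.where(x > lambd, x - lambd,
+                    mx.where(x < -lambd, x + lambd, mx.zeros_like(x)))
+
+softshrink_ref(mx.array([-1.0, -0.25, 0.25, 1.0]))
+# -> [-0.5, 0, 0, 0.5]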
\ No newline at end of file
diff --git a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.step.html b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.step.html
index 584d3ebd0..6d56b52bc 100644
--- a/docs/build/html/python/nn/_autosummary_functions/mlx.nn.step.html
+++ b/docs/build/html/python/nn/_autosummary_functions/mlx.nn.step.html
@@ -9,7 +9,7 @@
- mlx.nn.step — MLX 0.0.9 documentation
+ mlx.nn.step — MLX 0.1.0 documentation
@@ -673,10 +690,13 @@ simple functions.
 selu (x)
 Applies the Scaled Exponential Linear Unit.
-silu (x)
+softshrink (x[, lambd])
+Applies the Softshrink activation function.
+
+silu (x)
 Applies the Sigmoid Linear Unit.
-step (x[, threshold])
+step (x[, threshold])
 Applies the Step Activation Function.
diff --git a/docs/build/html/python/nn/init.html b/docs/build/html/python/nn/init.html
new file mode 100644
index 000000000..648f38e11
--- /dev/null
+++ b/docs/build/html/python/nn/init.html
@@ -0,0 +1,804 @@
+ Initializers — MLX 0.1.0 documentation
+Initializers
+The mlx.nn.init package contains commonly used initializers for neural
+network parameters. Initializers return a function which can be applied to any
+input mlx.core.array to produce an initialized output.
+For example:
+import mlx.core as mx
+import mlx.nn as nn
+
+init_fn = nn.init.uniform()
+
+# Produces a [2, 2] uniform matrix
+param = init_fn(mx.zeros((2, 2)))
+
+
+To re-initialize all the parameters in an mlx.nn.Module from, say, a uniform
+distribution, you can do:
+import mlx.nn as nn
+model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
+init_fn = nn.init.uniform(low=-0.1, high=0.1)
+model.apply(init_fn)
+
+
+
+
+constant (value[, dtype])
+An initializer that returns an array filled with value.
+
+normal ([mean, std, dtype])
+An initializer that returns samples from a normal distribution.
+
+uniform ([low, high, dtype])
+An initializer that returns samples from a uniform distribution.
+
+identity ([dtype])
+An initializer that returns an identity matrix.
+
+glorot_normal ([dtype])
+A Glorot normal initializer.
+
+glorot_uniform ([dtype])
+A Glorot uniform initializer.
+
+he_normal ([dtype])
+Build a He normal initializer.
+
+he_uniform ([dtype])
+A He uniform (Kaiming uniform) initializer.
\ No newline at end of file
diff --git a/docs/build/html/python/nn/initializers.html b/docs/build/html/python/nn/initializers.html
new file mode 100644
index 000000000..c01560e4b
--- /dev/null
+++ b/docs/build/html/python/nn/initializers.html
@@ -0,0 +1,778 @@
+ Initializers — MLX documentation
+Initializers
+
+
+constant (value[, dtype])
+An initializer that returns an array filled with value.
+
+normal ([mean, std, dtype])
+An initializer that returns samples from a normal distribution.
+
+uniform ([low, high, dtype])
+An initializer that returns random values from a uniform distribution.
+
+identity ([dtype])
+An initializer that returns an identity matrix.
+
+glorot_normal ([dtype])
+A Glorot normal initializer.
+
+glorot_uniform ([dtype])
+A Glorot uniform initializer.
+
+he_normal ([dtype])
+Build a He normal initializer.
+
+he_uniform ([dtype])
+A He uniform (Kaiming uniform) initializer.
\ No newline at end of file
diff --git a/docs/build/html/python/nn/layers.html b/docs/build/html/python/nn/layers.html
index 745d0fdc0..9dfadf611 100644
--- a/docs/build/html/python/nn/layers.html
+++ b/docs/build/html/python/nn/layers.html
@@ -9,7 +9,7 @@
- Layers — MLX 0.0.9 documentation
+ Layers — MLX 0.1.0 documentation
@@ -722,10 +739,13 @@ document.write(`
 SinusoidalPositionalEncoding (dims[, ...])
 Implements sinusoidal positional encoding.
-Step ([threshold])
+Softshrink ([lambd])
+Applies the Softshrink function.
+
+Step ([threshold])
 Applies the Step Activation Function.
-Transformer (dims, num_heads, ...)
+Transformer (dims, num_heads, ...)
 Implements a standard Transformer model.
diff --git a/docs/build/html/python/nn/losses.html b/docs/build/html/python/nn/losses.html
index 62fdfee6f..bc37f49fc 100644
--- a/docs/build/html/python/nn/losses.html
+++ b/docs/build/html/python/nn/losses.html
@@ -9,7 +9,7 @@
- Loss Functions — MLX 0.0.9 documentation
+ Loss Functions — MLX 0.1.0 documentation
@@ -650,7 +667,7 @@ document.write(`
Loss Functions
-binary_cross_entropy (logits, targets[, ...])
+binary_cross_entropy (inputs, targets[, ...])
 Computes the binary cross entropy loss.
 cosine_similarity_loss (x1, x2[, axis, eps, ...])
@@ -659,31 +676,34 @@ document.write(`
 cross_entropy (logits, targets[, weights, ...])
 Computes the cross entropy loss.
-hinge_loss (inputs, targets[, reduction])
+gaussian_nll_loss (inputs, targets, vars[, ...])
+Computes the negative log likelihood loss for a Gaussian distribution.
+
+hinge_loss (inputs, targets[, reduction])
 Computes the hinge loss between inputs and targets.
-huber_loss (inputs, targets[, delta, reduction])
+huber_loss (inputs, targets[, delta, reduction])
 Computes the Huber loss between inputs and targets.
-kl_div_loss (inputs, targets[, axis, reduction])
+kl_div_loss (inputs, targets[, axis, reduction])
 Computes the Kullback-Leibler divergence loss.
-l1_loss (predictions, targets[, reduction])
+l1_loss (predictions, targets[, reduction])
 Computes the L1 loss.
-log_cosh_loss (inputs, targets[, reduction])
+log_cosh_loss (inputs, targets[, reduction])
 Computes the log cosh loss between inputs and targets.
-mse_loss (predictions, targets[, reduction])
+mse_loss (predictions, targets[, reduction])
 Computes the mean squared error loss.
-nll_loss (inputs, targets[, axis, reduction])
+nll_loss (inputs, targets[, axis, reduction])
 Computes the negative log likelihood loss.
-smooth_l1_loss (predictions, targets[, beta, ...])
+smooth_l1_loss (predictions, targets[, beta, ...])
 Computes the smooth L1 loss.
-triplet_loss (anchors, positives, negatives)
+triplet_loss (anchors, positives, negatives)
 Computes the triplet loss for a set of anchor, positive, and negative samples.
diff --git a/docs/build/html/python/nn/module.html b/docs/build/html/python/nn/module.html
index c01ac92f6..700526581 100644
--- a/docs/build/html/python/nn/module.html
+++ b/docs/build/html/python/nn/module.html
@@ -9,7 +9,7 @@
- Module — MLX 0.0.9 documentation
+ Module — MLX 0.1.0 documentation
@@ -724,7 +741,7 @@ set by calling Module.load_weights
 (file_or_weights[, strict])
-Update the model's weights from a .npz or a list.
+Update the model's weights from a .npz, a .safetensors file, or a list.
 Module.modules ()
 Return a list with all the modules in this instance.
@@ -736,7 +753,7 @@ set by calling mlx.core.array
 members of this Module as a dict of dicts and lists.
 Module.save_weights (file)
-Save the model's weights to a .npz file.
+Save the model's weights to a file.
 Module.train ([mode])
 Set the model in or out of training mode.
diff --git a/docs/build/html/python/ops.html b/docs/build/html/python/ops.html
index b4d1a3c40..edf1b084f 100644
--- a/docs/build/html/python/ops.html
+++ b/docs/build/html/python/ops.html
@@ -9,7 +9,7 @@
- Operations — MLX 0.0.9 documentation
+ Operations — MLX 0.1.0 documentation
@@ -659,7 +676,7 @@ document.write(`
 all (a, /[, axis, keepdims, stream])
 An and reduction over the given axes.
-allclose (a, b, /[, rtol, atol, stream])
+allclose (a, b, /[, rtol, atol, equal_nan, ...])
 Approximate comparison of two arrays.
 any (a, /[, axis, keepdims, stream])
@@ -731,6 +748,12 @@ document.write(`
 dequantize (w, /, scales, biases[, ...])
 Dequantize the matrix w using the provided scales and biases and the group_size and bits configuration.
+diag (a, /[, k, stream])
+Extract a diagonal or construct a diagonal matrix.
+
+diagonal (a[, offset, axis1, axis2, stream])
+Return specified diagonals.
+
 divide (a, b[, stream])
 Element-wise division.
@@ -800,7 +823,7 @@ document.write(`
 linspace (start, stop[, num, dtype, stream])
 Generate num evenly spaced numbers over interval [start, stop].
-load (file, /[, format, stream])
+load (file, /[, format, return_metadata, stream])
 Load array(s) from a binary file.
 log (a, /, *[, stream])
@@ -905,7 +928,7 @@ document.write(`
 savez_compressed (file, *args, **kwargs)
 Save several arrays to a binary file in compressed .npz format.
-save_gguf (file, arrays)
+save_gguf (file, arrays, metadata)
 Save array(s) to a binary file in .gguf format.
 save_safetensors (file, arrays)
@@ -968,7 +991,7 @@ document.write(`
 tanh (a, /, *[, stream])
 Element-wise hyperbolic tangent.
-tensordot (a, b, /[, dims, stream])
+tensordot (a, b, /[, axes, stream])
 Compute the tensor dot product along the specified axes.
 transpose (a, /[, axes, stream])
diff --git a/docs/build/html/python/optimizers.html b/docs/build/html/python/optimizers.html
index e8e03b42d..0632bae0f 100644
--- a/docs/build/html/python/optimizers.html
+++ b/docs/build/html/python/optimizers.html
@@ -9,7 +9,7 @@
- Optimizers — MLX 0.0.9 documentation
+ Optimizers — MLX 0.1.0 documentation
SGD
(learning_rate[, momentum, weight_decay, ...])
-Stochastic gradient descent optimizer.
+The stochastic gradient descent optimizer.
RMSprop
(learning_rate[, alpha, eps])
-Implementation of the RMSprop optimizer [1].
+The RMSprop optimizer [1].
Adagrad
(learning_rate[, eps])
-Implementation of the Adagrad optimizer [1].
+The Adagrad optimizer [1].
-AdaDelta
(learning_rate[, rho, eps])
-Implementation of the AdaDelta optimizer with learning rate[1].
+Adafactor
([learning_rate, eps, ...])
+The Adafactor optimizer.
-Adam
(learning_rate[, betas, eps])
-Implementation of the Adam optimizer [1].
+AdaDelta
(learning_rate[, rho, eps])
+The AdaDelta optimizer with a learning rate [1].
-AdamW
(learning_rate[, betas, eps, weight_decay])
-Implementation of the AdamW optimizer [1].
+Adam
(learning_rate[, betas, eps])
+The Adam optimizer [1].
-Adamax
(learning_rate[, betas, eps])
-Implementation of the Adamax optimizer.
+AdamW
(learning_rate[, betas, eps, weight_decay])
+The AdamW optimizer [1].
-Lion
(learning_rate[, betas, weight_decay])
-Implementation of the Lion optimizer [1].
+Adamax
(learning_rate[, betas, eps])
+The Adamax optimizer, a variant of Adam based on the infinity norm [1].
+
+Lion
(learning_rate[, betas, weight_decay])
+The Lion optimizer [1].
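Besides the cosmetic docstring rewording ("Implementation of the X optimizer" becomes "The X optimizer"), the substantive change in this table is the new Adafactor entry. Its summary lists learning_rate as optional, consistent with Adafactor's ability to derive a relative step size on its own. A sketch of constructing it and applying one gradient step, assuming the Optimizer.apply_gradients(gradients, model) API of this release:

    import mlx.core as mx
    import mlx.optimizers as optim

    opt = optim.Adafactor()  # no learning_rate: relies on the relative step size
    params = {"w": mx.zeros((2, 2))}
    grads = {"w": mx.ones((2, 2))}
    params = opt.apply_gradients(grads, params)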
@@ -720,12 +740,12 @@ model’s parameters and the optimizer state.
diff --git a/docs/build/html/python/tree_utils.html b/docs/build/html/python/tree_utils.html
index 964b7e1d1..fd3f760d6 100644
--- a/docs/build/html/python/tree_utils.html
+++ b/docs/build/html/python/tree_utils.html
@@ -9,7 +9,7 @@
- Tree Utils — MLX 0.0.9 documentation
+ Tree Utils — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -440,6 +444,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -451,14 +456,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
diff --git a/docs/build/html/searchindex.js b/docs/build/html/searchindex.js
index 890004ff5..73e85e6e0 100644
--- a/docs/build/html/searchindex.js
+++ b/docs/build/html/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.Stream", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.round", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.ceil", "python/_autosummary/mlx.core.clip", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.dequantize", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.divmod", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", "python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", "python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", 
"python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.flatten", "python/_autosummary/mlx.core.floor", "python/_autosummary/mlx.core.floor_divide", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", "python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.inner", "python/_autosummary/mlx.core.isinf", "python/_autosummary/mlx.core.isnan", "python/_autosummary/mlx.core.isneginf", "python/_autosummary/mlx.core.isposinf", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.linalg.norm", "python/_autosummary/mlx.core.linspace", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_and", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logical_or", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.moveaxis", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.outer", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.quantize", "python/_autosummary/mlx.core.quantized_matmul", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.repeat", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.round", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.save_gguf", "python/_autosummary/mlx.core.save_safetensors", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.simplify", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", "python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", "python/_autosummary/mlx.core.stack", "python/_autosummary/mlx.core.stop_gradient", 
"python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.swapaxes", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.tensordot", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.tri", "python/_autosummary/mlx.core.tril", "python/_autosummary/mlx.core.triu", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.optimizers.AdaDelta", "python/_autosummary/mlx.optimizers.Adagrad", "python/_autosummary/mlx.optimizers.Adam", "python/_autosummary/mlx.optimizers.AdamW", "python/_autosummary/mlx.optimizers.Adamax", "python/_autosummary/mlx.optimizers.Lion", "python/_autosummary/mlx.optimizers.Optimizer", "python/_autosummary/mlx.optimizers.OptimizerState", "python/_autosummary/mlx.optimizers.RMSprop", "python/_autosummary/mlx.optimizers.SGD", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/array", "python/data_types", "python/devices_and_streams", "python/fft", "python/linalg", "python/nn", "python/nn/_autosummary/mlx.nn.ALiBi", "python/nn/_autosummary/mlx.nn.BatchNorm", "python/nn/_autosummary/mlx.nn.Conv1d", "python/nn/_autosummary/mlx.nn.Conv2d", "python/nn/_autosummary/mlx.nn.Dropout", "python/nn/_autosummary/mlx.nn.Dropout2d", "python/nn/_autosummary/mlx.nn.Dropout3d", "python/nn/_autosummary/mlx.nn.Embedding", "python/nn/_autosummary/mlx.nn.GELU", "python/nn/_autosummary/mlx.nn.GroupNorm", "python/nn/_autosummary/mlx.nn.InstanceNorm", "python/nn/_autosummary/mlx.nn.LayerNorm", "python/nn/_autosummary/mlx.nn.Linear", "python/nn/_autosummary/mlx.nn.Mish", "python/nn/_autosummary/mlx.nn.Module.apply", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules", "python/nn/_autosummary/mlx.nn.Module.children", "python/nn/_autosummary/mlx.nn.Module.eval", "python/nn/_autosummary/mlx.nn.Module.filter_and_map", "python/nn/_autosummary/mlx.nn.Module.freeze", "python/nn/_autosummary/mlx.nn.Module.leaf_modules", "python/nn/_autosummary/mlx.nn.Module.load_weights", "python/nn/_autosummary/mlx.nn.Module.modules", "python/nn/_autosummary/mlx.nn.Module.named_modules", "python/nn/_autosummary/mlx.nn.Module.parameters", "python/nn/_autosummary/mlx.nn.Module.save_weights", "python/nn/_autosummary/mlx.nn.Module.train", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters", "python/nn/_autosummary/mlx.nn.Module.training", "python/nn/_autosummary/mlx.nn.Module.unfreeze", "python/nn/_autosummary/mlx.nn.Module.update", "python/nn/_autosummary/mlx.nn.Module.update_modules", "python/nn/_autosummary/mlx.nn.MultiHeadAttention", "python/nn/_autosummary/mlx.nn.PReLU", "python/nn/_autosummary/mlx.nn.QuantizedLinear", "python/nn/_autosummary/mlx.nn.RMSNorm", "python/nn/_autosummary/mlx.nn.ReLU", "python/nn/_autosummary/mlx.nn.RoPE", "python/nn/_autosummary/mlx.nn.SELU", "python/nn/_autosummary/mlx.nn.Sequential", "python/nn/_autosummary/mlx.nn.SiLU", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding", "python/nn/_autosummary/mlx.nn.Step", "python/nn/_autosummary/mlx.nn.Transformer", 
"python/nn/_autosummary_functions/mlx.nn.gelu", "python/nn/_autosummary_functions/mlx.nn.gelu_approx", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss", "python/nn/_autosummary_functions/mlx.nn.mish", "python/nn/_autosummary_functions/mlx.nn.prelu", "python/nn/_autosummary_functions/mlx.nn.relu", "python/nn/_autosummary_functions/mlx.nn.selu", "python/nn/_autosummary_functions/mlx.nn.silu", "python/nn/_autosummary_functions/mlx.nn.step", "python/nn/functions", "python/nn/layers", "python/nn/losses", "python/nn/module", "python/ops", "python/optimizers", "python/random", "python/transforms", "python/tree_utils", "usage/function_transforms", "usage/indexing", "usage/lazy_evaluation", "usage/numpy", "usage/quick_start", "usage/saving_and_loading", "usage/unified_memory", "usage/using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", "python/_autosummary/mlx.core.Stream.rst", "python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", "python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", "python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", 
"python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.round.rst", "python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.ceil.rst", "python/_autosummary/mlx.core.clip.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.dequantize.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.divmod.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fft.fft.rst", "python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.flatten.rst", "python/_autosummary/mlx.core.floor.rst", "python/_autosummary/mlx.core.floor_divide.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.inner.rst", "python/_autosummary/mlx.core.isinf.rst", "python/_autosummary/mlx.core.isnan.rst", "python/_autosummary/mlx.core.isneginf.rst", "python/_autosummary/mlx.core.isposinf.rst", "python/_autosummary/mlx.core.jvp.rst", "python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.linalg.norm.rst", "python/_autosummary/mlx.core.linspace.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_and.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logical_or.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", 
"python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.moveaxis.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", "python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.outer.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.quantize.rst", "python/_autosummary/mlx.core.quantized_matmul.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.repeat.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.round.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.save_gguf.rst", "python/_autosummary/mlx.core.save_safetensors.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", "python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.simplify.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", "python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stack.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.swapaxes.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.tensordot.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.tri.rst", "python/_autosummary/mlx.core.tril.rst", "python/_autosummary/mlx.core.triu.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.optimizers.AdaDelta.rst", "python/_autosummary/mlx.optimizers.Adagrad.rst", "python/_autosummary/mlx.optimizers.Adam.rst", "python/_autosummary/mlx.optimizers.AdamW.rst", "python/_autosummary/mlx.optimizers.Adamax.rst", "python/_autosummary/mlx.optimizers.Lion.rst", "python/_autosummary/mlx.optimizers.Optimizer.rst", "python/_autosummary/mlx.optimizers.OptimizerState.rst", 
"python/_autosummary/mlx.optimizers.RMSprop.rst", "python/_autosummary/mlx.optimizers.SGD.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fft.rst", "python/linalg.rst", "python/nn.rst", "python/nn/_autosummary/mlx.nn.ALiBi.rst", "python/nn/_autosummary/mlx.nn.BatchNorm.rst", "python/nn/_autosummary/mlx.nn.Conv1d.rst", "python/nn/_autosummary/mlx.nn.Conv2d.rst", "python/nn/_autosummary/mlx.nn.Dropout.rst", "python/nn/_autosummary/mlx.nn.Dropout2d.rst", "python/nn/_autosummary/mlx.nn.Dropout3d.rst", "python/nn/_autosummary/mlx.nn.Embedding.rst", "python/nn/_autosummary/mlx.nn.GELU.rst", "python/nn/_autosummary/mlx.nn.GroupNorm.rst", "python/nn/_autosummary/mlx.nn.InstanceNorm.rst", "python/nn/_autosummary/mlx.nn.LayerNorm.rst", "python/nn/_autosummary/mlx.nn.Linear.rst", "python/nn/_autosummary/mlx.nn.Mish.rst", "python/nn/_autosummary/mlx.nn.Module.apply.rst", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules.rst", "python/nn/_autosummary/mlx.nn.Module.children.rst", "python/nn/_autosummary/mlx.nn.Module.eval.rst", "python/nn/_autosummary/mlx.nn.Module.filter_and_map.rst", "python/nn/_autosummary/mlx.nn.Module.freeze.rst", "python/nn/_autosummary/mlx.nn.Module.leaf_modules.rst", "python/nn/_autosummary/mlx.nn.Module.load_weights.rst", "python/nn/_autosummary/mlx.nn.Module.modules.rst", "python/nn/_autosummary/mlx.nn.Module.named_modules.rst", "python/nn/_autosummary/mlx.nn.Module.parameters.rst", "python/nn/_autosummary/mlx.nn.Module.save_weights.rst", "python/nn/_autosummary/mlx.nn.Module.train.rst", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters.rst", "python/nn/_autosummary/mlx.nn.Module.training.rst", "python/nn/_autosummary/mlx.nn.Module.unfreeze.rst", "python/nn/_autosummary/mlx.nn.Module.update.rst", "python/nn/_autosummary/mlx.nn.Module.update_modules.rst", "python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/nn/_autosummary/mlx.nn.PReLU.rst", "python/nn/_autosummary/mlx.nn.QuantizedLinear.rst", "python/nn/_autosummary/mlx.nn.RMSNorm.rst", "python/nn/_autosummary/mlx.nn.ReLU.rst", "python/nn/_autosummary/mlx.nn.RoPE.rst", "python/nn/_autosummary/mlx.nn.SELU.rst", "python/nn/_autosummary/mlx.nn.Sequential.rst", "python/nn/_autosummary/mlx.nn.SiLU.rst", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst", "python/nn/_autosummary/mlx.nn.Step.rst", "python/nn/_autosummary/mlx.nn.Transformer.rst", "python/nn/_autosummary_functions/mlx.nn.gelu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst", 
"python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst", "python/nn/_autosummary_functions/mlx.nn.mish.rst", "python/nn/_autosummary_functions/mlx.nn.prelu.rst", "python/nn/_autosummary_functions/mlx.nn.relu.rst", "python/nn/_autosummary_functions/mlx.nn.selu.rst", "python/nn/_autosummary_functions/mlx.nn.silu.rst", "python/nn/_autosummary_functions/mlx.nn.step.rst", "python/nn/functions.rst", "python/nn/layers.rst", "python/nn/losses.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "usage/function_transforms.rst", "usage/indexing.rst", "usage/lazy_evaluation.rst", "usage/numpy.rst", "usage/quick_start.rst", "usage/saving_and_loading.rst", "usage/unified_memory.rst", "usage/using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.Stream", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", "mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.cos", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.item", "mlx.core.array.log", "mlx.core.array.log1p", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.round", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.sum", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", "mlx.core.array_equal", "mlx.core.broadcast_to", "mlx.core.ceil", "mlx.core.clip", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.dequantize", "mlx.core.divide", "mlx.core.divmod", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.eye", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.flatten", "mlx.core.floor", "mlx.core.floor_divide", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.inner", "mlx.core.isinf", "mlx.core.isnan", "mlx.core.isneginf", "mlx.core.isposinf", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.linalg.norm", "mlx.core.linspace", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_and", "mlx.core.logical_not", "mlx.core.logical_or", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", "mlx.core.mean", "mlx.core.min", "mlx.core.minimum", "mlx.core.moveaxis", "mlx.core.multiply", "mlx.core.negative", 
"mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.outer", "mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.quantize", "mlx.core.quantized_matmul", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.repeat", "mlx.core.reshape", "mlx.core.round", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.save_gguf", "mlx.core.save_safetensors", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.simplify", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stack", "mlx.core.stop_gradient", "mlx.core.subtract", "mlx.core.sum", "mlx.core.swapaxes", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.tensordot", "mlx.core.transpose", "mlx.core.tri", "mlx.core.tril", "mlx.core.triu", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.value_and_grad", "mlx.optimizers.AdaDelta", "mlx.optimizers.Adagrad", "mlx.optimizers.Adam", "mlx.optimizers.AdamW", "mlx.optimizers.Adamax", "mlx.optimizers.Lion", "mlx.optimizers.Optimizer", "mlx.optimizers.OptimizerState", "mlx.optimizers.RMSprop", "mlx.optimizers.SGD", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "Array", "Data Types", "Devices and Streams", "FFT", "Linear Algebra", "Neural Networks", "mlx.nn.ALiBi", "mlx.nn.BatchNorm", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Dropout", "mlx.nn.Dropout2d", "mlx.nn.Dropout3d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GroupNorm", "mlx.nn.InstanceNorm", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.Mish", "mlx.nn.Module.apply", "mlx.nn.Module.apply_to_modules", "mlx.nn.Module.children", "mlx.nn.Module.eval", "mlx.nn.Module.filter_and_map", "mlx.nn.Module.freeze", "mlx.nn.Module.leaf_modules", "mlx.nn.Module.load_weights", "mlx.nn.Module.modules", "mlx.nn.Module.named_modules", "mlx.nn.Module.parameters", "mlx.nn.Module.save_weights", "mlx.nn.Module.train", "mlx.nn.Module.trainable_parameters", "mlx.nn.Module.training", "mlx.nn.Module.unfreeze", "mlx.nn.Module.update", "mlx.nn.Module.update_modules", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.QuantizedLinear", "mlx.nn.RMSNorm", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.SinusoidalPositionalEncoding", "mlx.nn.Step", "mlx.nn.Transformer", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cosine_similarity_loss", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.hinge_loss", "mlx.nn.losses.huber_loss", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.log_cosh_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.losses.smooth_l1_loss", "mlx.nn.losses.triplet_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", "mlx.nn.selu", "mlx.nn.silu", "mlx.nn.step", "Functions", "Layers", "Loss Functions", "Module", "Operations", "Optimizers", "Random", "Transforms", "Tree Utils", "Function Transforms", "Indexing Arrays", "Lazy Evaluation", "Conversion to NumPy and Other Frameworks", "Quick Start 
Guide", "Saving and Loading Arrays", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 6, 213, 282, 284, 285, 287, 288, 289, 290, 291, 292, 293, 294], "provid": [1, 3, 72, 98, 182, 187, 206, 213, 228, 233, 235, 243, 244, 245, 248, 257, 279, 282, 293, 295], "open": [1, 6, 15, 145, 149], "flexibl": [1, 5, 245], "which": [1, 3, 4, 5, 6, 15, 33, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 98, 103, 104, 105, 106, 107, 110, 112, 138, 141, 142, 151, 152, 155, 156, 157, 158, 159, 172, 173, 178, 187, 189, 190, 219, 220, 222, 228, 232, 251, 263, 265, 271, 285, 288, 289, 290, 294, 295], "user": [1, 3, 213], "mai": [1, 110, 219, 288, 289], "add": [1, 3, 80, 117, 135, 138, 216, 217, 288, 294], "special": 1, "without": [1, 3, 5, 174, 246, 279, 287, 290, 291, 294], "much": [1, 3, 290], "hassl": 1, "while": [1, 3, 6, 152, 251, 290, 291], "librari": [1, 6, 213], "suppli": 1, "effici": [1, 3, 5, 219, 251, 290, 292], "can": [1, 3, 5, 6, 11, 15, 47, 58, 73, 74, 75, 78, 99, 100, 108, 109, 110, 117, 124, 127, 129, 140, 141, 145, 148, 149, 175, 187, 213, 221, 232, 243, 253, 282, 284, 285, 287, 288, 289, 290, 291, 292, 293, 294, 295], "compos": [1, 5, 213, 288, 292], "ani": [1, 3, 5, 15, 164, 205, 206, 207, 213, 222, 228, 229, 232, 248, 257, 287, 288, 290, 292, 293, 294], "number": [1, 15, 52, 66, 72, 81, 98, 101, 107, 111, 135, 138, 139, 141, 144, 147, 149, 151, 153, 164, 182, 184, 187, 189, 190, 213, 215, 216, 217, 219, 220, 223, 224, 246, 247, 257, 285, 288, 295], "applic": [1, 6], "aris": [1, 291], "case": [1, 3, 84, 87, 88, 90, 91, 92, 93, 94, 122, 152, 172, 219, 252, 256, 271, 276, 278, 288, 292, 293, 294, 295], "where": [1, 4, 81, 138, 187, 190, 215, 216, 217, 218, 219, 220, 222, 223, 224, 225, 226, 232, 247, 249, 252, 254, 256, 258, 259, 260, 274, 276, 277, 278, 288, 289], "new": [1, 4, 61, 128, 152, 173, 183, 206, 246, 282, 284, 289, 290, 291], "function": [1, 2, 3, 4, 5, 13, 74, 76, 77, 98, 107, 110, 122, 162, 164, 187, 189, 190, 194, 206, 213, 222, 227, 229, 233, 243, 247, 253, 256, 257, 258, 259, 260, 273, 278, 284, 285, 287, 289, 290, 291, 293], "highli": [1, 6], "optim": [1, 2, 4, 5, 244, 288, 290], "ar": [1, 2, 3, 4, 5, 6, 13, 15, 60, 61, 63, 67, 81, 83, 84, 86, 87, 89, 90, 92, 93, 98, 103, 104, 105, 106, 107, 110, 112, 122, 134, 135, 136, 138, 139, 140, 141, 142, 145, 148, 149, 158, 159, 172, 173, 178, 187, 189, 190, 200, 205, 206, 215, 216, 217, 218, 219, 220, 223, 224, 225, 226, 235, 246, 248, 279, 282, 287, 288, 289, 290, 291, 292, 293, 294], "need": [1, 3, 4, 5, 60, 138, 213, 244, 245, 255, 257, 285, 288, 290, 291, 292, 294], "For": [1, 3, 6, 110, 138, 207, 213, 215, 219, 228, 233, 240, 243, 248, 251, 255, 285, 289, 290, 291, 292, 293, 294], "you": [1, 3, 4, 5, 6, 213, 255, 257, 285, 288, 289, 291, 293, 294], "design": [1, 2, 5, 285, 294], "your": [1, 3, 6, 282, 288, 290], "own": [1, 6, 291], "link": [1, 6], "top": [1, 226], "core": [1, 2, 3, 4, 213, 215, 224, 235, 238, 241, 261, 282, 284, 291, 292], "we": [1, 2, 3, 4, 72, 138, 139, 198, 200, 213, 221, 253, 285, 287, 288, 290, 294], "inner": 1, "work": [1, 3, 6, 288, 289, 290], "go": [1, 3, 288], "over": [1, 3, 4, 12, 14, 22, 23, 24, 25, 65, 66, 84, 87, 90, 93, 102, 110, 111, 121, 123, 125, 126, 136, 137, 154, 167, 168, 176, 182, 188, 215, 216, 217, 223, 225, 249, 263, 288], "simpl": [1, 3, 4, 213, 221, 279, 288, 290], "learn": [1, 2, 4, 5, 195, 196, 197, 198, 199, 200, 203, 204, 215, 223, 224, 225, 247, 249], "step": [1, 3, 4, 15, 213], "involv": [1, 284], "ad": [1, 2, 6, 195, 196, 197, 198, 199, 203, 224, 282, 290, 
293], "let": [1, 2, 3, 288, 290, 291], "s": [1, 2, 3, 4, 35, 44, 72, 83, 84, 86, 87, 89, 90, 92, 93, 98, 110, 112, 125, 134, 138, 141, 153, 156, 157, 187, 188, 190, 194, 201, 213, 232, 233, 235, 239, 243, 284, 285, 288, 290, 291, 292, 293, 294], "sai": [1, 3, 290], "would": [1, 3, 289, 290, 291, 294], "like": [1, 3, 5, 133, 193, 220, 268, 288, 290, 291, 292, 294], "an": [1, 3, 4, 6, 8, 12, 14, 26, 61, 65, 66, 78, 81, 94, 97, 101, 110, 123, 126, 128, 132, 133, 135, 137, 138, 139, 151, 152, 153, 169, 172, 177, 178, 179, 182, 184, 190, 192, 193, 195, 201, 202, 205, 206, 213, 218, 223, 225, 226, 228, 246, 247, 248, 257, 259, 274, 285, 287, 288, 289, 290, 291, 292, 293, 294, 295], "take": [1, 3, 4, 98, 107, 124, 127, 133, 139, 179, 187, 189, 190, 193, 246, 285, 288, 289, 293, 294, 295], "two": [1, 11, 13, 60, 73, 75, 83, 86, 92, 99, 100, 108, 109, 117, 122, 124, 127, 129, 134, 177, 248, 262, 288, 289, 294], "arrai": [1, 3, 4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 213, 215, 228, 235, 238, 241, 247, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 278, 282, 288, 290, 291, 292, 294], "x": [1, 2, 3, 4, 76, 101, 110, 139, 142, 153, 158, 162, 164, 185, 186, 191, 200, 206, 213, 215, 222, 223, 224, 225, 226, 227, 228, 247, 249, 250, 252, 254, 255, 256, 258, 259, 260, 271, 273, 274, 275, 276, 277, 278, 282, 284, 288, 289, 290, 291, 292, 294], "y": [1, 2, 3, 4, 164, 191, 196, 213, 215, 219, 223, 224, 225, 226, 249, 264, 271, 284, 288, 290, 291], "scale": [1, 3, 72, 138, 139, 219, 220, 246, 251, 252, 255, 276], "them": [1, 3, 213, 233, 243, 294], "both": [1, 11, 73, 74, 75, 99, 100, 108, 109, 110, 117, 124, 127, 129, 141, 175, 224, 284, 288, 292, 294], "some": [1, 2, 3, 4, 233, 243, 288, 290], "coeffici": [1, 195, 197, 198, 199, 200], "alpha": [1, 138, 198, 203, 252, 272, 274, 276], "beta": [1, 72, 138, 197, 198, 199, 200, 215, 223, 224, 225, 271], "respect": [1, 2, 4, 98, 138, 187, 206, 213, 215, 222, 223, 224, 225, 282, 288, 292], "togeth": [1, 4, 138, 206], "get": [1, 2, 4, 6, 66, 143, 202, 213, 288, 290, 294], "z": [1, 164, 290], "well": [1, 3, 213, 233, 243, 246, 290], "veri": [1, 3, 246, 290, 294], "easili": 1, "do": [1, 3, 6, 198, 213, 234, 243, 282, 288, 290], "just": [1, 4, 289], "write": [1, 3, 213, 291], "out": [1, 6, 219, 220, 240, 288, 289], "follow": [1, 3, 4, 5, 6, 15, 67, 72, 110, 138, 195, 196, 197, 198, 199, 200, 204, 213, 259, 260, 266, 285, 288, 294], "import": [1, 2, 3, 4, 6, 110, 158, 164, 187, 205, 206, 207, 213, 215, 224, 235, 261, 282, 288, 289, 290, 291, 292], "mx": [1, 2, 3, 4, 110, 112, 158, 164, 187, 213, 215, 224, 228, 235, 250, 261, 262, 266, 275, 282, 284, 285, 288, 289, 290, 291, 292, 293, 294, 295], "def": [1, 2, 3, 4, 164, 187, 213, 282, 288, 289, 290, 291, 294], "simple_axpbi": 1, "float": [1, 13, 15, 57, 96, 97, 110, 139, 140, 145, 148, 149, 195, 196, 197, 198, 199, 200, 203, 204, 209, 215, 218, 219, 220, 223, 224, 225, 228, 
249, 251, 255, 256, 257, 262, 263, 265, 271, 272, 278], "return": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 37, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 205, 206, 207, 213, 230, 232, 234, 236, 237, 238, 241, 248, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 282, 287, 288, 289, 290, 291, 293, 294], "thi": [1, 3, 4, 6, 12, 13, 14, 15, 22, 23, 24, 25, 107, 110, 117, 121, 122, 123, 125, 126, 136, 137, 141, 164, 167, 168, 169, 176, 178, 188, 213, 218, 219, 220, 229, 230, 232, 233, 236, 237, 238, 241, 243, 244, 245, 246, 248, 256, 259, 260, 268, 278, 282, 287, 288, 290, 291, 293], "perform": [1, 3, 5, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 122, 139, 153, 167, 178, 213, 223, 257, 289, 290, 294], "leav": [1, 206], "differenti": [1, 5], "howev": [1, 213, 222, 223, 285, 290, 291], "vector": [1, 2, 5, 102, 107, 110, 178, 189, 190, 221, 292], "math": [1, 3, 272], "often": [1, 220], "realiz": 1, "axpbi": 1, "routin": 1, "defin": [1, 2, 3, 4, 6, 110, 139, 202, 205, 291], "same": [1, 3, 6, 60, 61, 66, 67, 88, 91, 92, 93, 98, 107, 135, 141, 153, 189, 191, 213, 215, 218, 223, 224, 248, 272, 282, 285, 289, 294], "realli": 1, "part": [1, 288, 289], "doe": [1, 3, 6, 213, 289, 290, 291], "fast": [1, 164, 222, 260, 294], "so": [1, 3, 6, 98, 164, 187, 218, 284, 290, 294], "decid": [1, 206, 232], "want": [1, 3, 288, 294], "reli": 1, "acceler": [1, 215], "framework": [1, 5], "continu": [1, 288], "impos": 1, "our": [1, 3, 4, 195, 196, 197, 199, 200, 253], "assumpt": 1, "also": [1, 3, 4, 5, 6, 11, 73, 74, 75, 84, 87, 90, 93, 99, 100, 108, 109, 117, 124, 127, 129, 138, 175, 194, 202, 213, 232, 244, 246, 248, 252, 254, 258, 276, 277, 279, 284, 288, 289, 290, 291, 292, 295], "assum": [1, 3, 206, 213, 223], "how": [1, 3, 4, 213, 216, 217, 221, 289, 294], "gradient": [1, 2, 4, 98, 174, 187, 194, 195, 197, 198, 199, 200, 204, 213, 233, 244, 248, 268, 282, 284, 288, 289, 290, 291, 292], "ins": 1, "what": [1, 3, 206], "coincid": 1, "right": [1, 6, 138, 222, 259, 260, 265, 272], "place": [1, 3, 153, 290, 291], "cours": [1, 288], "The": [1, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 35, 44, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 149, 150, 151, 152, 156, 157, 162, 163, 165, 166, 167, 168, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 209, 215, 216, 217, 218, 219, 220, 221, 223, 224, 225, 226, 229, 235, 244, 245, 246, 248, 249, 251, 253, 255, 256, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 
278, 282, 284, 288, 289, 290, 291, 292, 293, 294, 295], "structur": [1, 288], "from": [1, 3, 4, 5, 72, 89, 90, 92, 93, 97, 110, 112, 122, 133, 138, 140, 141, 142, 143, 145, 148, 158, 172, 174, 175, 178, 179, 191, 193, 205, 206, 207, 213, 226, 233, 235, 246, 271, 287, 288, 290, 291, 292, 293, 294], "frontend": 1, "api": [1, 288], "redirect": 1, "when": [1, 3, 5, 6, 110, 112, 216, 217, 266, 271, 282, 285, 294], "appropri": 1, "fallback": 1, "metal": 1, "vjp": [1, 292], "jvp": [1, 292], "In": [1, 3, 4, 122, 138, 195, 196, 197, 199, 200, 206, 213, 219, 223, 282, 287, 288, 290, 293, 294], "one": [1, 3, 6, 57, 63, 66, 80, 81, 110, 115, 122, 139, 141, 172, 175, 243, 294], "sentenc": 1, "comput": [1, 2, 3, 4, 5, 6, 72, 98, 107, 110, 117, 125, 134, 138, 164, 167, 174, 182, 187, 188, 189, 194, 195, 197, 198, 199, 200, 213, 215, 223, 224, 225, 233, 244, 248, 249, 251, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 284, 288, 292, 294], "graph": [1, 3, 4, 5, 164, 288], "rule": 1, "evalu": [1, 3, 4, 5, 78, 107, 189, 213, 231, 240, 282, 284, 292], "said": [1, 3], "start": [1, 2, 3, 5, 6, 15, 111, 169, 289, 294], "discuss": 1, "more": [1, 4, 8, 57, 122, 156, 157, 213, 215, 219, 251, 255, 285, 288, 289, 292, 294], "detail": [1, 8, 195, 196, 197, 199, 200, 213, 219, 251, 255, 289, 292], "thei": [1, 2, 3, 67, 200, 253, 264, 282, 287, 290, 292, 293, 294], "c": [1, 3, 110, 209, 215, 216, 217, 219, 220, 224, 291, 292, 294], "scalar": [1, 11, 13, 26, 37, 57, 60, 61, 63, 73, 74, 75, 96, 97, 98, 99, 100, 108, 109, 110, 111, 117, 118, 119, 120, 122, 124, 127, 129, 135, 145, 148, 149, 175, 187, 191, 194, 272, 288, 290, 292], "sum": [1, 2, 11, 102, 110, 121, 167, 182, 213, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 289, 291], "element": [1, 10, 11, 16, 17, 18, 19, 20, 21, 24, 52, 62, 68, 69, 72, 73, 74, 75, 76, 77, 79, 81, 95, 96, 99, 100, 103, 104, 105, 106, 108, 109, 113, 114, 115, 116, 117, 118, 119, 120, 124, 127, 129, 130, 136, 138, 139, 150, 151, 154, 162, 163, 165, 166, 170, 171, 175, 178, 180, 181, 187, 191, 218, 219, 220, 227, 247, 251, 254, 273, 274, 277, 288], "wise": [1, 10, 11, 16, 17, 18, 19, 20, 21, 62, 68, 69, 73, 74, 75, 76, 77, 79, 95, 96, 99, 100, 108, 109, 113, 114, 115, 116, 117, 118, 119, 120, 124, 127, 129, 130, 150, 154, 162, 163, 165, 166, 170, 171, 175, 180, 181, 219, 220, 227, 247, 254, 273, 274, 277], "numpi": [1, 3, 4, 5, 11, 13, 15, 61, 73, 74, 75, 99, 100, 108, 109, 117, 122, 124, 127, 129, 175, 290, 292, 293], "style": [1, 11, 13, 73, 74, 75, 99, 100, 108, 109, 117, 122, 124, 127, 129, 175], "broadcast": [1, 11, 13, 61, 63, 73, 74, 75, 97, 99, 100, 108, 109, 117, 122, 124, 127, 129, 140, 141, 148, 149, 175, 179, 191, 246], "between": [1, 5, 63, 257, 262, 264, 265, 268, 290, 294], "input": [1, 2, 3, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 73, 74, 75, 76, 77, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 133, 134, 135, 136, 137, 138, 139, 147, 150, 151, 152, 153, 154, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 185, 186, 187, 188, 190, 191, 193, 215, 216, 217, 219, 220, 221, 223, 224, 225, 226, 246, 248, 249, 251, 256, 257, 261, 262, 264, 265, 266, 268, 270, 272, 278, 288, 289, 292, 293], "upcast": 1, "const": 1, "factor": [1, 263], "streamordevic": 1, "stream": 
[1, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 99, 100, 101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 161, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 188, 191, 192, 193, 294], "schedul": [1, 294], "itself": 1, "call": [1, 3, 4, 27, 96, 213, 221, 233, 243, 253, 282, 284, 288, 290], "other": [1, 3, 5, 110, 200, 213, 234, 282, 289, 290, 292], "within": [1, 24], "simplest": [1, 213], "wai": [1, 3, 6, 213, 288, 289], "about": [1, 3, 4, 290, 294], "term": [1, 195, 196, 197, 198, 199, 203], "exist": [1, 3, 233, 243], "auto": [1, 6], "ax": [1, 12, 14, 22, 23, 58, 80, 83, 84, 86, 87, 89, 90, 92, 93, 102, 110, 121, 123, 125, 126, 135, 137, 167, 172, 176, 177, 182, 183, 188, 288], "multipli": [1, 138, 139, 218, 255], "earlier": 1, "goal": 1, "themselv": 1, "contain": [1, 3, 50, 88, 89, 90, 110, 118, 119, 120, 138, 169, 191, 213, 232, 234, 235, 257, 282, 288], "act": [1, 268], "data": [1, 4, 5, 8, 15, 81, 91, 92, 97, 101, 111, 132, 148, 184, 192, 220, 289, 291], "nor": [1, 98, 187], "rather": [1, 288, 294], "easi": [1, 213], "interfac": 1, "block": [1, 3, 257], "A": [1, 3, 5, 6, 50, 60, 98, 107, 110, 112, 121, 122, 138, 140, 141, 142, 144, 145, 148, 149, 169, 173, 187, 189, 190, 194, 197, 199, 205, 206, 207, 213, 215, 219, 223, 224, 225, 227, 232, 236, 237, 244, 245, 249, 253, 255, 257, 260, 272, 273, 282, 284, 288, 290, 291], "It": [1, 3, 6, 98, 187, 199, 201, 213, 245, 248, 291, 293], "creat": [1, 3, 6, 81, 101, 213, 282, 284, 289, 291], "output": [1, 3, 6, 12, 13, 14, 15, 22, 23, 24, 61, 81, 88, 91, 92, 93, 97, 98, 101, 110, 111, 121, 123, 125, 126, 132, 133, 136, 137, 140, 141, 142, 144, 145, 148, 149, 158, 159, 167, 172, 176, 179, 184, 187, 188, 189, 190, 191, 192, 193, 215, 216, 217, 224, 226, 246, 248, 256, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 278, 288, 289, 290, 291, 292, 293, 294], "given": [1, 12, 14, 24, 61, 63, 64, 72, 78, 80, 82, 83, 84, 85, 86, 87, 91, 92, 93, 97, 110, 121, 123, 125, 126, 137, 145, 153, 167, 169, 176, 184, 185, 186, 188, 218, 232, 246, 262], "set": [1, 3, 4, 6, 202, 222, 226, 231, 233, 240, 243, 244, 248, 251, 256, 262, 272, 278, 282, 285, 288, 290], "further": [1, 6, 288], "class": [1, 3, 4, 7, 8, 9, 26, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 282], "under": [1, 110], "These": [1, 179, 294], "word": 1, "bit": [1, 72, 138, 139, 209, 228, 248], "abstract": 1, "back": [1, 3, 291], "give": [1, 3, 4, 24], "ourselv": 1, "concret": [1, 226, 290, 294], "imag": [1, 217, 219, 220], "public": [1, 213], "explicit": [1, 285, 291], "alpha_": 1, "beta_": 1, "must": [1, 6, 63, 78, 97, 110, 140, 141, 145, 148, 149, 191, 291], "know": [1, 3], "popul": 1, "To": [1, 2, 3, 4, 6, 213, 288, 292], "avoid": 1, "unnecessari": [1, 3], "alloc": [1, 282], "respons": 1, "space": [1, 
111, 270], "void": 1, "eval_cpu": 1, "std": 1, "overrid": 1, "eval_gpu": 1, "jacobian": [1, 107, 189, 292], "product": [1, 102, 107, 122, 134, 137, 182, 189, 246, 292], "primal": [1, 107, 189], "tangent": [1, 20, 21, 107, 180, 181], "int": [1, 3, 4, 7, 9, 12, 14, 15, 22, 23, 24, 25, 29, 30, 31, 32, 40, 41, 42, 43, 45, 48, 50, 53, 56, 57, 59, 61, 64, 65, 66, 72, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 98, 101, 110, 111, 121, 123, 125, 126, 128, 132, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 151, 152, 153, 167, 168, 169, 172, 173, 176, 177, 178, 179, 182, 183, 184, 185, 186, 187, 188, 190, 192, 213, 215, 216, 217, 221, 223, 224, 225, 226, 246, 248, 249, 251, 255, 257, 262, 263, 266, 270, 272, 282], "argnum": [1, 98, 187, 288], "cotan": 1, "across": [1, 223], "pair": [1, 135, 235, 251], "repres": [1, 3, 272, 291], "axi": [1, 3, 4, 12, 14, 22, 23, 24, 25, 29, 30, 31, 32, 40, 41, 42, 43, 45, 53, 56, 59, 64, 80, 82, 85, 88, 89, 90, 91, 92, 93, 110, 121, 123, 125, 126, 128, 135, 136, 137, 141, 151, 167, 168, 169, 172, 173, 176, 177, 178, 179, 183, 188, 190, 262, 263, 266, 270, 272, 289], "correspond": [1, 12, 14, 57, 63, 72, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 121, 123, 126, 137, 176, 182, 190, 206, 288], "dimens": [1, 3, 12, 14, 22, 23, 44, 50, 57, 66, 80, 89, 90, 92, 93, 94, 102, 110, 121, 122, 123, 125, 126, 137, 138, 141, 147, 176, 179, 182, 183, 188, 215, 216, 217, 219, 220, 223, 224, 225, 246, 249, 251, 257, 288], "vmap": [1, 288, 290, 292], "print": [1, 2, 3, 4, 6, 205, 206, 207, 213, 285, 288, 289, 290, 291, 292], "ostream": 1, "os": [1, 6], "equival": [1, 27, 47, 58, 74, 96, 178, 222, 245, 247, 248], "check": [1, 6, 60, 235, 288, 289], "bool": [1, 12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 56, 57, 59, 60, 110, 121, 123, 125, 126, 137, 139, 140, 145, 148, 149, 176, 188, 204, 215, 216, 217, 223, 224, 225, 226, 228, 232, 233, 235, 240, 243, 246, 248, 251, 255, 257], "is_equival": 1, "privat": 1, "fall": 1, "eval": [1, 2, 3, 4, 164, 213, 282, 284, 288, 290, 292], "deriv": [1, 288, 290], "base": [1, 110, 114, 116, 199, 201, 251, 257, 282, 284, 285, 289], "abov": [1, 3, 6, 138, 185, 198, 213, 288, 289, 290, 294], "demonstr": [1, 291], "treat": [1, 60, 89, 90, 92, 93, 178], "paramet": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 203, 204, 205, 206, 207, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 232, 233, 235, 240, 243, 244, 245, 246, 247, 248, 249, 251, 253, 255, 256, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 278, 279, 282, 284, 288, 290], "produc": [1, 246], "through": [1, 174, 200, 257, 288, 291], "construct": [1, 4, 97, 132, 192], "its": [1, 6, 122, 136, 147, 164, 184, 194, 197, 198, 199, 207, 213, 248, 291, 294], "type": [1, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 
24, 25, 33, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 201, 205, 213, 251, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 289], "shape": [1, 3, 4, 47, 60, 61, 65, 66, 82, 85, 88, 91, 92, 93, 97, 107, 122, 132, 133, 140, 141, 142, 144, 145, 148, 149, 152, 179, 189, 191, 192, 193, 213, 215, 216, 217, 219, 220, 224, 226, 235, 272, 284, 288, 289, 292, 294], "pass": [1, 3, 4, 47, 58, 134, 135, 187, 194, 205, 206, 213, 233, 243, 244, 245, 248, 253, 290], "re": [1, 4, 6], "now": [1, 3, 6, 248, 291], "promot": 1, "dtype": [1, 3, 15, 26, 33, 57, 81, 97, 101, 110, 111, 132, 142, 144, 145, 148, 149, 184, 192, 209, 261, 288, 289, 291, 292, 293], "promoted_dtyp": 1, "promote_typ": 1, "float32": [1, 15, 81, 101, 110, 111, 132, 142, 144, 148, 149, 184, 192, 209, 261, 288, 289, 290, 291, 292, 293], "non": [1, 6, 227, 241, 273, 282], "point": [1, 2, 3, 6, 96, 139, 209], "out_dtyp": 1, "is_floating_point": 1, "cast": [1, 33, 91, 92, 93, 112, 228, 291], "up": [1, 3, 248], "determin": [1, 293], "x_cast": 1, "astyp": [1, 3, 228, 291], "y_cast": 1, "broadcasted_input": 1, "broadcast_arrai": 1, "out_shap": 1, "0": [1, 2, 3, 4, 6, 7, 15, 48, 53, 59, 64, 65, 66, 81, 94, 98, 110, 135, 140, 149, 151, 153, 169, 173, 184, 185, 186, 187, 188, 190, 195, 197, 198, 199, 200, 203, 204, 205, 213, 215, 216, 217, 218, 219, 220, 222, 223, 224, 225, 247, 250, 251, 252, 255, 256, 257, 259, 260, 261, 263, 264, 265, 271, 272, 274, 275, 276, 278, 282, 285, 288, 289, 290, 291, 292, 293], "unique_ptr": 1, "make_uniqu": 1, "to_stream": 1, "handl": [1, 213], "resolv": 1, "No": [1, 3], "happen": [1, 3, 257, 284, 290], "alon": [1, 291], "effect": [1, 219, 290], "onli": [1, 3, 5, 6, 60, 65, 66, 110, 138, 209, 213, 232, 233, 235, 240, 243, 244, 245, 282, 288, 293, 294], "execut": [1, 6, 291, 294], "depend": [1, 2, 57, 110, 289, 293, 294], "devic": [1, 5, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 99, 100, 101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 160, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 188, 191, 192, 193, 294, 295], "specifi": [1, 15, 33, 66, 89, 90, 97, 98, 110, 111, 128, 132, 141, 151, 177, 178, 179, 182, 183, 187, 190, 192, 215, 256, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 278, 288, 294], "memori": [1, 5, 164, 282, 290, 291], "ha": [1, 3, 4, 5, 57, 88, 89, 91, 92, 93, 98, 141, 215, 226, 282, 284, 289, 290, 292, 294], "been": [1, 3, 290], "try": [1, 6], "naiv": [1, 288], "gener": [1, 2, 15, 81, 89, 90, 111, 140, 
144, 145, 148, 149, 257, 285, 289, 290, 295], "version": [1, 6, 72, 117, 121, 138, 167, 190, 285, 288, 289], "declar": 1, "member": [1, 213, 238, 241], "method": [1, 3, 7, 8, 9, 26, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 213, 282], "each": [1, 50, 72, 78, 122, 135, 138, 139, 141, 151, 158, 159, 169, 183, 190, 191, 219, 220, 221, 223, 251, 257, 263, 285, 290], "find": [1, 2, 6], "pointwis": 1, "captur": [1, 213], "templat": 1, "axpby_impl": 1, "typenam": 1, "t": [1, 3, 76, 139, 187, 195, 196, 197, 198, 199, 200, 203, 204, 213, 288, 294], "readi": 1, "fill": [1, 97, 133, 184, 193], "malloc_or_wait": 1, "synchron": 1, "avail": [1, 2, 3, 4, 6, 8, 209, 294], "There": [1, 213], "wait": [1, 3], "here": [1, 3, 288, 290, 293, 294], "request": 1, "pressur": 1, "condit": [1, 191, 294], "set_data": 1, "nbyte": 1, "collect": [1, 202, 206, 287], "pointer": 1, "x_ptr": 1, "y_ptr": 1, "out_ptr": 1, "relev": 1, "static_cast": 1, "size_t": 1, "out_idx": 1, "size": [1, 3, 4, 50, 66, 72, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 101, 110, 138, 139, 141, 152, 169, 172, 213, 216, 217, 221, 224, 248, 290, 291], "map": [1, 4, 112, 206, 221, 228], "linear": [1, 3, 4, 5, 206, 213, 222, 235, 248, 250, 252, 254, 258, 259, 260, 275, 276, 277, 282], "indic": [1, 13, 22, 23, 24, 25, 98, 103, 104, 105, 106, 169, 178, 179, 187, 240, 242, 263, 289], "offset": [1, 3], "x_offset": 1, "elem_to_loc": 1, "stride": [1, 65, 66, 216, 217, 251, 289], "y_offset": 1, "contigu": 1, "regularli": 1, "default": [1, 6, 12, 14, 15, 22, 23, 24, 25, 60, 64, 65, 66, 72, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 98, 101, 110, 111, 112, 121, 123, 125, 126, 132, 136, 137, 138, 139, 140, 141, 142, 144, 145, 147, 148, 149, 151, 152, 153, 168, 169, 172, 173, 176, 182, 183, 184, 185, 186, 187, 188, 190, 192, 195, 196, 197, 198, 199, 200, 202, 203, 204, 209, 215, 216, 217, 224, 226, 228, 233, 235, 240, 243, 246, 247, 248, 251, 255, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 282, 285, 287, 288, 291, 293, 295], "row": [1, 81, 101, 138, 184], "major": 1, "henc": [1, 138], "doesn": [1, 213], "addit": [1, 3, 11, 215, 223, 225, 246, 249, 282, 288], "abl": [1, 138], "all": [1, 4, 6, 13, 24, 66, 78, 81, 84, 87, 90, 93, 122, 135, 136, 172, 201, 213, 228, 229, 233, 236, 237, 238, 241, 243, 246, 248, 255, 257, 282, 285, 289, 290, 292, 295], "incom": 1, "accordingli": 1, "dispatch": 1, "float16": [1, 112, 209, 228, 290, 291], "bfloat16": [1, 291], "complex64": 1, "throw": 1, "error": [1, 6, 76, 77, 169, 222, 248, 258, 259, 260, 268, 269, 288, 291], "encount": [1, 288], "unexpect": [1, 15], "regist": [1, 4], "op": [1, 134, 233, 290], "assert": 1, "2": [1, 2, 3, 4, 66, 76, 83, 86, 88, 89, 90, 91, 92, 93, 110, 116, 122, 138, 147, 182, 184, 185, 186, 195, 196, 197, 198, 203, 209, 213, 217, 222, 249, 255, 259, 265, 271, 272, 282, 288, 289, 290, 291, 292, 293, 294], "1": [1, 3, 4, 15, 24, 25, 65, 66, 82, 83, 85, 86, 88, 89, 90, 91, 92, 93, 94, 102, 110, 122, 134, 136, 138, 141, 149, 162, 168, 178, 187, 195, 196, 197, 198, 199, 200, 203, 204, 209, 213, 215, 216, 217, 218, 219, 220, 222, 223, 224, 225, 226, 247, 249, 251, 252, 255, 256, 259, 260, 261, 262, 263, 264, 265, 266, 268, 270, 271, 272, 276, 278, 282, 284, 288, 289, 291, 292, 293, 294], "correct": [1, 6, 197, 198, 199, 289, 290], "els": [1, 3, 213, 233, 290], "float16_t": 1, "bfloat16_t": 1, "complex64_t": 1, "runtime_error": 1, "support": [1, 3, 5, 6, 13, 65, 66, 112, 122, 138, 288, 289, 291, 293], "have": [1, 3, 6, 60, 89, 90, 92, 93, 122, 141, 200, 
205, 246, 253, 287, 289, 290, 294], "rememb": 1, "3": [1, 3, 6, 110, 200, 285, 289, 291, 292], "complic": 1, "keep": [1, 12, 14, 22, 23, 121, 123, 125, 126, 137, 176, 188, 213, 232, 288, 290], "mind": [1, 3], "half": [1, 15, 145, 149, 251, 290], "precis": [1, 3, 213, 222], "direct": [1, 3, 200, 230, 294], "fix": [1, 3, 6, 290], "possibl": [1, 3, 122, 169, 221, 289, 294], "due": 1, "transpos": [1, 3, 27, 139], "aren": 1, "guarante": 1, "fit": [1, 138, 294], "requir": [1, 3, 213, 290, 291], "column": [1, 81, 101, 138], "inplac": 1, "expect": [1, 3, 216, 217, 218, 219, 220, 255, 257, 289], "answer": 1, "copi": [1, 3, 5, 136, 168, 291], "simpli": [1, 3, 6, 250, 275, 282, 288], "catlas_saxpbi": 1, "axpby_impl_acceler": 1, "first": [1, 2, 3, 4, 6, 94, 98, 118, 120, 122, 136, 147, 177, 182, 187, 197, 198, 199, 205, 213, 223, 262, 288, 291, 294], "mode": [1, 67, 231, 240, 242], "i": [1, 3, 107, 110, 198, 213, 216, 217, 219, 220, 233, 268, 288], "e": [1, 4, 6, 76, 107, 162, 196, 215, 216, 217, 219, 220, 223, 224, 225, 233, 249, 279, 284, 290, 295], "match": [1, 6, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 235, 289, 291], "transposit": 1, "data_s": 1, "items": 1, "flag": [1, 291], "copy_inplac": 1, "copytyp": 1, "n": [1, 3, 26, 65, 66, 81, 82, 84, 85, 87, 88, 91, 93, 101, 184, 188, 215, 216, 217, 219, 220, 268, 272], "incx": 1, "inci": 1, "great": 1, "But": [1, 294], "criteria": 1, "luckili": [1, 290], "alwai": [1, 205, 288], "With": 1, "final": [1, 2, 3, 4], "singl": [1, 4, 78, 107, 112, 135, 189, 289, 293], "row_contigu": 1, "col_contigu": 1, "common": [1, 290], "hit": 1, "mileston": 1, "enough": [1, 290], "run": [1, 3, 4, 5, 6, 134, 164, 195, 197, 198, 199, 215, 228, 290, 294, 295], "If": [1, 3, 6, 12, 14, 15, 22, 23, 24, 25, 57, 60, 63, 64, 67, 78, 91, 92, 93, 96, 97, 98, 110, 112, 121, 122, 123, 125, 126, 132, 135, 136, 137, 141, 151, 167, 168, 169, 176, 178, 179, 182, 187, 188, 190, 192, 206, 215, 216, 217, 223, 225, 226, 233, 235, 243, 248, 251, 253, 255, 272, 288, 290, 293, 294, 295], "plan": 1, "stop": [1, 3, 15, 111, 174, 288, 289], "enjoi": 1, "speed": 1, "appl": [1, 3, 5, 6, 294], "silicon": [1, 3, 5, 6, 294], "address": 1, "shade": 1, "languag": [1, 209], "kernel": [1, 65, 66, 289], "written": 1, "help": [1, 3, 294], "resourc": 1, "walkthrough": 1, "pipelin": 1, "specif": [1, 6, 288], "cpp": 1, "algorithm": [1, 200], "launch": [1, 289], "exactli": [1, 3, 235, 288], "mani": [1, 169, 216, 217, 221, 290], "thread": 1, "pick": 1, "updat": [1, 2, 3, 4, 198, 200, 204, 206, 215, 228, 235, 245, 284, 290], "assign": [1, 282], "axpby_gener": 1, "buffer": [1, 291], "constant": [1, 3, 6, 135, 203, 215, 223, 225, 249, 272, 291], "4": [1, 3, 72, 110, 138, 139, 158, 209, 215, 224, 248, 257, 289, 292, 294], "5": [1, 2, 3, 6, 110, 140, 203, 215, 218, 219, 220, 224, 271, 288, 289], "x_stride": 1, "6": [1, 3, 110, 158, 203, 257, 259, 260, 272, 289, 292], "y_stride": 1, "7": [1, 3, 110, 138, 289], "ndim": [1, 110], "8": [1, 3, 6, 110, 138, 195, 196, 197, 198, 199, 203, 209, 224, 257, 262, 289, 292, 294], "uint": 1, "index": [1, 5, 7, 9, 24, 80, 81, 98, 136, 178, 179, 187], "thread_position_in_grid": 1, "convert": [1, 57, 248, 290, 291, 292], "instanti": [1, 4, 290], "uniqu": [1, 285], "host": 1, "name": [1, 112, 138, 139, 156, 157, 158, 159, 202, 213, 223, 232, 235, 237, 289, 293], "identifi": [1, 205, 287], "instantiate_axpbi": 1, "type_nam": 1, "host_nam": 1, "axpby_general_": 1, "compil": [1, 6, 290], "mlx_ext": 1, "metallib": [1, 6], "see": [1, 3, 4, 6, 8, 28, 29, 30, 31, 32, 34, 36, 38, 39, 40, 
41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 110, 156, 157, 213, 215, 219, 222, 231, 247, 248, 251, 252, 255, 258, 259, 260, 276, 288, 289, 292, 294], "later": [1, 6], "co": [1, 255, 288], "locat": [1, 244, 245, 294], "share": [1, 5, 72, 138, 139], "register_librari": 1, "potenti": 1, "path": [1, 6, 158, 159, 235], "tri": 1, "load": [1, 4, 5, 235], "hasn": 1, "alreadi": [1, 3], "static": [1, 6], "object": [1, 8, 26, 37, 57, 140, 145, 148, 149, 190, 205, 206, 219, 287], "why": [1, 3], "packag": [1, 2, 4], "process": [1, 3, 67, 206, 220, 221, 257, 287], "logic": [1, 118, 119, 120], "grid": 1, "shown": 1, "below": [1, 6, 110, 184, 186, 209, 290], "prepar": [1, 3], "carri": 1, "should": [1, 2, 3, 4, 6, 107, 138, 164, 179, 187, 189, 205, 213, 216, 217, 219, 220, 240, 246, 253, 264, 282, 287, 288, 290, 291, 295], "d": [1, 3, 102, 110, 122, 134, 178, 184, 185, 186, 195, 197, 199, 207, 220, 294], "ostringstream": 1, "kname": 1, "axpby_": 1, "general_": 1, "type_to_nam": 1, "make": [1, 3, 4, 6, 122, 213, 290, 292, 294], "sure": [1, 3, 6, 213], "look": [1, 3], "folder": 1, "get_colocated_mtllib_path": 1, "get_kernel": 1, "str": [1, 67, 98, 110, 112, 155, 156, 157, 158, 159, 187, 205, 207, 228, 229, 232, 233, 235, 237, 239, 243, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272], "encod": [1, 251, 255, 257], "compute_encod": 1, "get_command_encod": 1, "setcomputepipelinest": 1, "those": [1, 3, 213], "nelem": 1, "set_array_buff": 1, "setbyt": 1, "sizeof": 1, "threadgroup": 1, "higher": [1, 102, 288], "than": [1, 3, 57, 67, 74, 99, 100, 108, 109, 122, 200, 206, 251, 256, 271, 278, 288, 294], "max": [1, 110, 124, 199, 247, 262, 264, 272, 274, 288, 294], "allow": [1, 201, 213, 245, 282, 289, 292], "tgp_size": 1, "min": [1, 110, 127, 247, 274], "maxtotalthreadsperthreadgroup": 1, "3d": [1, 215, 220], "mtl": 1, "group_dim": 1, "grid_dim": 1, "divid": [1, 96, 138], "among": 1, "dispatchthread": 1, "few": [1, 3, 4, 5, 164, 290, 292], "thing": [1, 3], "note": [1, 3, 6, 13, 65, 66, 89, 90, 110, 138, 141, 213, 291, 293], "befor": [1, 3, 6, 24, 136, 232, 257, 289, 290], "move": [1, 128, 294], "track": [1, 213, 215], "activ": [1, 6, 219, 227, 256, 257, 273, 278, 279], "command": [1, 6], "instead": [1, 6, 213, 245, 255, 288, 290], "end_encod": 1, "end": [1, 138, 252, 256, 265, 271, 276, 278], "until": [1, 290, 292], "limit": [1, 63, 289], "flush": 1, "enqueu": 1, "commit": 1, "associ": [1, 158, 159, 290], "suggest": 1, "deeper": 1, "dive": 1, "studi": 1, "come": [1, 3, 288], "far": [1, 284], "built": [1, 6, 290], "includ": [1, 229, 248, 288, 289, 292, 293, 295], "forward": [1, 187, 290], "diff": 1, "push": 1, "along": [1, 22, 23, 64, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 110, 151, 167, 169, 173, 178, 179, 182, 213], "similarli": [1, 6, 122, 288, 290], "scale_arr": 1, "contribut": 1, "tangent_x": 1, "tangent_i": 1, "revers": [1, 183, 255], "arg": [1, 3, 8, 47, 58, 78, 158, 159, 164], "push_back": 1, "fulli": [1, 5, 291, 294], "overal": 1, "directori": [1, 3, 6], "extens": [1, 112, 209, 293], "h": [1, 65, 66, 110, 215, 217, 219, 220, 288, 290], "mlx_sample_extens": 1, "__init__": [1, 3, 4, 7, 8, 9, 26, 213, 282], "py": [1, 3, 6], "cmakelist": 1, "txt": 1, "setup": [1, 2, 4, 6], "hold": [1, 3, 8, 110, 201], "instal": 1, "pybind11": [1, 6], "sinc": [1, 3, 4, 200, 282, 291, 294], "compon": [1, 3], "etc": [1, 138, 213], "becom": 1, "pybind11_modul": 1, "m": [1, 6, 81, 110, 184, 195], "doc": [1, 4], "sampl": [1, 2, 3, 111, 140, 141, 142, 145, 148, 149, 272, 285], "_a": 1, "pos_onli": 1, 
"kw_onli": 1, "none": [1, 3, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191, 192, 193, 205, 206, 222, 228, 232, 233, 243, 246, 255, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 289], "r": [1, 3, 187, 219], "pbdoc": 1, "most": [1, 141, 213, 288, 289, 290], "complex": [1, 89, 90, 91, 92, 93, 140, 145, 148, 149, 205, 213, 245, 288], "bell": 1, "whistl": 1, "liter": 1, "string": [1, 291, 293], "modul": [1, 3, 4, 194, 248, 253, 257, 287, 290], "ensur": [1, 6, 268], "caster": 1, "find_packag": 1, "config": 1, "add_librari": 1, "sourc": [1, 128, 183], "target_sourc": 1, "cmake_current_list_dir": 1, "header": 1, "target_include_directori": 1, "target_link_librari": 1, "attach": 1, "conveni": [1, 4], "mlx_build_metallib": 1, "target": [1, 187, 261, 263, 264, 265, 266, 267, 268, 269, 270, 271], "destin": [1, 128], "automat": [1, 5, 112, 292, 293, 294], "practic": 1, "mlx_build_met": [1, 6], "mlx_ext_metallib": 1, "titl": 1, "include_dir": 1, "project_source_dir": 1, "mlx_include_dir": 1, "output_directori": 1, "cmake_library_output_directori": 1, "add_depend": 1, "endif": 1, "pybind11_add_modul": 1, "build_shared_lib": 1, "target_link_opt": 1, "wl": 1, "rpath": 1, "loader_path": 1, "onc": [1, 164], "describ": [1, 290], "util": [1, 3, 5, 6, 158, 213], "__name__": [1, 3], "__main__": [1, 3], "descript": [1, 3, 209], "ext_modul": 1, "cmakeextens": 1, "cmdclass": 1, "build_ext": 1, "cmakebuild": 1, "package_dir": 1, "package_data": 1, "dylib": 1, "zip_saf": 1, "fals": [1, 3, 12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 56, 59, 60, 110, 121, 123, 125, 126, 137, 176, 188, 191, 204, 205, 206, 209, 223, 224, 226, 233, 235, 243, 246, 248, 251, 255, 257, 291], "python_requir": 1, "even": [1, 3, 290, 291], "though": [1, 3, 290, 291], "j8": 1, "libmlx_ext": 1, "cpython": 1, "3x": 1, "darwin": 1, "pip": [1, 6], "after": [1, 3, 4, 24, 96, 136, 138, 215, 223, 225, 246, 257, 271, 294], "plai": [1, 3], "ones": [1, 3, 133, 158, 164, 184, 244, 245, 248, 289], "b": [1, 3, 11, 13, 60, 73, 74, 75, 96, 99, 100, 102, 108, 109, 110, 117, 118, 120, 122, 124, 127, 129, 134, 138, 175, 182, 187, 226, 288, 289, 290, 291, 292, 293, 294], "f": [1, 2, 4, 110, 198, 213, 291], "item": [1, 2, 3, 4, 206, 290, 291, 292], "true": [1, 2, 3, 60, 110, 139, 167, 191, 205, 206, 209, 213, 215, 216, 217, 223, 224, 225, 226, 232, 233, 235, 240, 243, 248, 251, 255, 257, 268], "quick": [1, 5], "benchmark": 1, "compar": [1, 60], "time": [1, 3, 6, 164, 213, 288, 290, 294], "set_default_devic": 1, "256": [1, 4], "512": [1, 3, 257, 294], "random": [1, 2, 3, 4, 5, 215, 224, 235, 240, 288, 294, 295], "normal": [1, 2, 3, 148, 202, 215, 223, 224, 225, 249, 257, 291, 294], "bench": 1, "warm": 1, "rang": [1, 2, 3, 4, 6, 15, 111, 259, 260, 284, 285, 288, 290, 294], "100": [1, 2, 3, 288, 290, 294], "5000": 1, "simple_tim": 1, "custom_tim": 1, "3f": 
[1, 4], "custom": [1, 257], "114": 1, "109": 1, "modest": 1, "improv": [1, 3, 195, 196, 197, 198, 199, 203], "awai": [1, 3], "good": [1, 6, 294], "nn": [1, 3, 4, 158, 206, 213, 282, 284, 290], "grad": [1, 2, 4, 187, 284, 288, 289, 290, 292], "simplifi": [1, 290], "full": [1, 4, 47, 58, 67, 167, 244, 245, 290], "implement": [2, 4, 110, 195, 196, 197, 198, 199, 200, 201, 202, 203, 221, 232, 246, 251, 253, 255, 256, 257, 278, 288, 291], "basic": [2, 153, 288], "model": [2, 4, 5, 158, 194, 206, 213, 228, 231, 233, 235, 239, 240, 242, 243, 244, 246, 257, 282, 284, 290], "problem": [2, 4, 213], "metadata": 2, "num_featur": [2, 215], "num_exampl": 2, "1_000": 2, "num_it": 2, "10_000": 2, "iter": [2, 4, 206, 285, 290], "sgd": [2, 4, 200, 284], "lr": [2, 200], "01": [2, 198], "rate": [2, 195, 196, 197, 198, 199, 200, 203, 204], "ll": [2, 4, 265, 288], "synthet": 2, "dataset": [2, 290], "matrix": [2, 72, 81, 101, 110, 122, 138, 139, 248], "ground": [2, 3, 271], "truth": [2, 271], "w_star": 2, "valu": [2, 3, 10, 15, 22, 23, 37, 57, 60, 63, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 110, 111, 135, 140, 141, 142, 144, 145, 148, 149, 178, 179, 187, 190, 194, 198, 202, 205, 206, 209, 218, 219, 220, 224, 226, 232, 246, 247, 251, 256, 257, 261, 262, 263, 264, 265, 267, 268, 269, 270, 271, 278, 282, 288], "gaussian": [2, 222, 258, 259, 260], "nois": 2, "exampl": [2, 3, 4, 15, 110, 178, 213, 215, 224, 233, 235, 240, 243, 261, 284, 285, 288, 289, 290, 291, 292, 293], "noisi": 2, "label": [2, 263], "ep": [2, 195, 196, 197, 198, 199, 203, 215, 223, 224, 225, 249, 262, 272], "1e": [2, 4, 13, 195, 196, 197, 198, 199, 203, 215, 223, 224, 225, 249, 262, 272], "us": [2, 3, 4, 5, 6, 15, 72, 74, 94, 110, 122, 138, 139, 151, 152, 195, 197, 198, 199, 200, 201, 205, 213, 219, 221, 222, 226, 228, 232, 244, 245, 246, 248, 251, 255, 257, 259, 260, 262, 282, 284, 285, 287, 288, 289, 292, 294], "weight": [2, 65, 66, 198, 200, 204, 206, 213, 235, 239, 248, 263, 282, 288, 290], "squar": [2, 3, 101, 154, 170, 187, 195, 197, 198, 199, 206, 213, 249, 269, 271, 288, 291], "loss": [2, 4, 187, 213, 284, 288, 290], "loss_fn": [2, 4, 284, 288], "w": [2, 66, 72, 138, 139, 187, 204, 215, 217, 219, 220, 226, 288], "mean": [2, 3, 4, 187, 213, 215, 223, 233, 249, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 288, 291], "grad_fn": [2, 288], "initi": [2, 3, 213, 215, 223, 224, 225, 226, 247, 249, 282, 290], "randomli": [2, 3, 218, 219, 220], "Then": [2, 6], "repeatedli": 2, "_": [2, 3, 213, 285, 290, 294], "verifi": [2, 6], "close": [2, 5, 6, 13], "error_norm": 2, "5f": 2, "someth": [2, 3, 289], "00005": 2, "00364": 2, "complet": [2, 3, 6, 244, 245, 288, 294], "logist": [2, 162, 254, 259, 260, 277], "github": [2, 4, 6], "repo": [2, 4, 6], "enabl": [3, 6, 204], "larg": [3, 213, 246, 268, 290], "ish": 3, "transform": [3, 5, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 194, 213, 215, 223, 225, 226, 232, 233, 243, 248, 251, 289], "compromis": 3, "eas": 3, "llama": 3, "famili": 3, "less": [3, 24, 109, 136, 251, 271], "200": 3, "line": [3, 290, 291], "python": [3, 37, 50, 57, 78, 205, 206, 207, 282, 287, 288, 291], "neural": [3, 5, 203, 221, 227, 273, 282], "network": [3, 5, 203, 215, 219, 221, 282], "build": [3, 5, 282], "concis": 3, "architectur": [3, 6, 213, 245, 294], "notabl": [3, 5], "rope": [3, 213], "posit": [3, 24, 98, 106, 128, 136, 187, 206, 213, 216, 217, 246, 251, 255, 272], "option": [3, 12, 14, 15, 22, 23, 24, 25, 26, 31, 32, 64, 65, 66, 67, 72, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 97, 
98, 101, 105, 106, 110, 111, 112, 121, 123, 125, 126, 132, 135, 136, 137, 138, 139, 140, 141, 142, 144, 145, 147, 148, 149, 151, 152, 167, 168, 169, 172, 173, 176, 178, 179, 182, 183, 184, 185, 186, 187, 188, 190, 192, 195, 196, 197, 198, 199, 200, 203, 204, 205, 206, 215, 216, 217, 226, 228, 232, 233, 235, 243, 246, 248, 251, 255, 257, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 285, 293, 295], "kei": [3, 140, 141, 142, 144, 145, 147, 148, 149, 202, 205, 206, 232, 233, 243, 246, 251, 285, 287, 288], "cach": [3, 251], "concaten": 3, "project": [3, 246], "llamaattent": 3, "self": [3, 4, 7, 9, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 57, 58, 59, 213, 227, 273, 282], "dim": [3, 182, 221, 223, 224, 225, 246, 249, 251, 255, 257], "num_head": [3, 246, 257], "super": [3, 4, 213, 282], "tradit": [3, 219, 220, 251], "query_proj": 3, "bia": [3, 72, 138, 139, 197, 198, 199, 206, 213, 216, 217, 226, 233, 235, 243, 246, 248, 288], "key_proj": 3, "value_proj": 3, "out_proj": [3, 282], "__call__": [3, 4, 213, 282], "queri": [3, 246], "mask": [3, 240, 246, 289], "extract": [3, 213, 232, 282], "l": [3, 4, 213, 215, 216, 271], "reshap": [3, 110, 289], "combin": 3, "key_cach": 3, "value_cach": 3, "sqrt": [3, 76, 195, 196, 197, 198, 203, 215, 223, 224, 225, 226, 249, 255], "score": 3, "softmax": [3, 263], "values_hat": 3, "rm": [3, 6], "swiglu": 3, "rmsnorm": [3, 213], "llamaencoderlay": 3, "mlp_dim": [3, 257], "norm1": 3, "norm2": 3, "linear1": 3, "linear2": 3, "linear3": 3, "sigmoid": [3, 254, 259, 260, 261, 277], "instanc": [3, 138, 207, 213, 224, 228, 229, 230, 233, 236, 237, 243, 245, 253, 282, 291], "embed": [3, 213, 251, 255, 262], "emb": [3, 221, 255], "token": [3, 221], "num_lay": [3, 4, 284], "vocab_s": 3, "norm": [3, 199, 200, 223, 272], "multiheadattent": [3, 213], "create_additive_causal_mask": 3, "list": [3, 8, 12, 14, 26, 29, 30, 40, 41, 42, 43, 45, 50, 53, 56, 57, 59, 61, 64, 78, 80, 83, 84, 86, 87, 89, 90, 92, 93, 97, 98, 107, 110, 121, 123, 125, 126, 132, 135, 137, 140, 141, 142, 144, 145, 148, 149, 152, 167, 169, 172, 173, 176, 182, 183, 187, 188, 189, 192, 197, 198, 199, 200, 205, 207, 213, 233, 235, 236, 237, 238, 241, 243, 244, 245, 282, 287, 288, 290], "still": [3, 6, 110, 290], "consid": [3, 13, 60, 205, 206, 223, 287], "train": [3, 4, 213, 215, 218, 219, 220, 231, 233, 243], "ignor": [3, 63], "whatsoev": 3, "rest": [3, 206, 251], "subsect": 3, "prompt": 3, "autoregress": 3, "yield": [3, 4, 285], "temp": 3, "causal": 3, "save": [3, 5, 112, 138, 156, 157, 158, 159, 239, 290], "append": [3, 122, 290], "store": 3, "per": [3, 4, 72, 138, 139, 201, 215, 223, 224, 225, 249, 290], "care": [3, 290], "last": [3, 25, 57, 84, 87, 89, 90, 92, 93, 94, 102, 122, 141, 168, 182, 216, 217, 219, 220, 223, 291], "logit": [3, 141, 261, 263], "next": [3, 4], "categor": 3, "lazili": [3, 213], "noth": [3, 213, 290], "yet": [3, 110, 213, 282, 288, 289, 290, 292], "forc": [3, 4, 213, 292], "choos": [3, 251], "pars": 3, "feed": 3, "loop": [3, 4, 288, 290], "unsqueez": 3, "sequenc": [3, 215, 216, 257, 285, 294], "length": [3, 172, 215, 216], "len": [3, 84, 87, 90, 93], "overwrit": 3, "discard": [3, 205], "old": 3, "moment": [3, 197, 198, 199], "anymor": 3, "everyth": 3, "small": [3, 164, 215, 223, 225, 249, 268, 272, 294], "10": [3, 4, 114, 153, 158, 164, 206, 213, 235, 289], "12": 3, "8192": 3, "1024": 3, "actual": [3, 15, 235, 282, 290], "materi": [3, 5], "could": [3, 213], "20_000": 3, "machin": [3, 5, 6, 203], "8gb": 3, "ram": 
3, "32": [3, 4, 138, 139, 209], "44": 3, "doubl": 3, "bracket": 3, "becaus": [3, 213, 290], "batch": [3, 122, 215, 216, 217, 219, 220, 246, 290], "zip": [3, 4], "haven": 3, "anyth": [3, 187, 290], "result": [3, 15, 57, 72, 102, 110, 112, 122, 134, 139, 151, 153, 173, 182, 191, 206, 255, 288, 291], "similar": [3, 206, 244, 245, 246, 262, 291, 293], "runtim": 3, "section": [3, 6, 169, 272, 288], "access": [3, 37, 213, 282, 290, 294], "origin": [3, 195, 196, 197, 199, 200, 215, 291], "sentencepiec": 3, "pytorch": [3, 5, 223, 288], "compat": [3, 141, 293], "npz": [3, 112, 158, 159, 235, 239, 293], "file": [3, 6, 112, 155, 156, 157, 158, 159, 235, 239, 288, 293], "directli": 3, "argpars": 3, "itertool": [3, 206], "starmap": [3, 206], "np": [3, 4, 291, 292], "torch": [3, 291], "map_torch_to_mlx": 3, "tok_embed": 3, "elif": 3, "replac": [3, 244, 245, 257, 271], "attention_norm": 3, "ffn_norm": 3, "wq": 3, "wk": 3, "wv": 3, "wo": 3, "w1": 3, "w2": 3, "w3": 3, "ffn": 3, "separ": [3, 47, 58, 223], "submodul": [3, 4, 213, 233, 234, 243, 245], "feed_forward": 3, "parser": 3, "argumentpars": 3, "add_argu": 3, "torch_weight": 3, "output_fil": 3, "parse_arg": 3, "state": [3, 4, 201, 202, 213, 284, 285], "savez": [3, 293], "k": [3, 81, 184, 185, 186, 226, 233], "v": [3, 67, 213, 233, 291], "left": [3, 110, 138, 222, 251, 259, 260, 265, 272], "disk": 3, "text": [3, 200, 227, 252, 256, 264, 265, 268, 271, 272, 273, 274, 276, 278], "format": [3, 112, 155, 156, 157, 158, 159, 291], "oper": [3, 5, 33, 164, 167, 174, 179, 200, 213, 257, 288, 289, 290, 291, 292, 294, 295], "dictionari": [3, 156, 157, 201, 202, 205, 213, 232, 244, 245, 287, 293], "represent": [3, 138, 205, 207], "tree_unflatten": 3, "helper": 3, "weight_fil": 3, "incur": 3, "sever": [3, 65, 66, 158, 159, 293], "futur": [3, 248, 289, 290], "pth": 3, "current": [3, 5, 6, 65, 66, 138, 213, 290], "around": 3, "m1": [3, 288, 294], "ultra": 3, "7b": 3, "me": 3, "ishmael": 3, "year": 3, "ago": 3, "never": [3, 290], "long": 3, "info": [3, 6], "247": 3, "press": [3, 110], "enter": 3, "littl": 3, "monei": 3, "my": [3, 6], "purs": 3, "greater": [3, 24, 100, 136, 256, 278], "consequ": 3, "walk": 3, "down": 3, "gower": 3, "street": 3, "afternoon": 3, "heavi": 3, "rain": 3, "saw": [3, 288], "off": [3, 6, 290], "man": 3, "rag": 3, "who": 3, "sat": 3, "upon": [3, 206], "hi": 3, "bundl": 3, "hard": 3, "wet": 3, "he": 3, "were": [3, 294], "cry": 3, "watch": 3, "him": 3, "observ": 3, "numer": [3, 110, 117, 121, 167, 195, 196, 197, 198, 199, 203, 215, 223, 224, 225, 249, 262, 272, 290], "crowd": 3, "wa": [3, 202, 290], "hurri": 3, "437": 3, "330": 3, "second": [3, 118, 120, 122, 177, 187, 197, 198, 199, 262, 288, 294], "spent": 3, "amount": 3, "39": 3, "ms": 3, "By": [3, 288, 291], "bigger": 3, "remain": [3, 187, 218, 219, 220], "almost": 3, "nobodi": 3, "took": 3, "least": [3, 63, 138], "notic": [3, 288, 293], "distanc": [3, 272], "had": 3, "doubt": 3, "minut": 3, "straight": 3, "slowli": 3, "rais": [3, 110, 169, 235], "ey": 3, "speak": [3, 110], "resum": 3, "postur": 3, "stood": 3, "feel": 3, "pain": 3, "heart": 3, "smile": 3, "face": 3, "am": 3, "someon": 3, "three": 3, "quarter": 3, "hour": 3, "made": 3, "immedi": [3, 228], "repli": 3, "again": [3, 6, 213], "hand": [3, 288, 290], "did": 3, "accustom": 3, "thu": [3, 213], "question": [3, 290], "reason": [3, 289], "tell": [3, 291], "understand": 3, "579": 3, "690": 3, "num": [3, 111, 147], "500": [3, 294], "628": 3, "went": 3, "nervou": 3, "trembl": 3, "told": 3, "And": 3, "perhap": 3, "surpris": 3, "matter": [3, 
213], "shall": 3, "anyhow": 3, "friend": 3, "ye": 3, "slight": [3, 290], "kind": 3, "longer": [3, 67, 288], "soon": 3, "unless": [3, 110, 282], "unlik": [3, 13, 219, 220], "strang": 3, "amus": 3, "That": 3, "secret": 3, "disappoint": 3, "mine": 3, "cannot": [3, 63, 289, 291], "happi": 3, "ask": 3, "Is": [3, 255, 257], "shop": 3, "bui": 3, "food": 3, "633": 3, "21": 3, "475": 3, "su": 3, "j": [3, 6, 110, 196, 197, 199, 219], "lu": 3, "pan": 3, "murtadha": 3, "wen": 3, "liu": 3, "2021": 3, "roform": [3, 251], "enhanc": [3, 251, 290], "rotari": [3, 251], "arxiv": [3, 195, 200, 223, 224, 225, 227, 249, 273], "preprint": [3, 195, 200], "2104": 3, "09864": 3, "zhang": 3, "sennrich": 3, "2019": [3, 198], "root": [3, 154, 170, 249], "advanc": 3, "inform": [3, 4, 6, 156, 157, 213, 215, 222, 246, 288, 294], "system": [3, 6], "shazeer": 3, "2020": 3, "glu": 3, "variant": [3, 199, 271], "2002": 3, "05202": 3, "classifi": 4, "mnist": 4, "As": [4, 178, 213], "mlp": [4, 213, 257, 284], "inherit": [4, 287], "standard": [4, 37, 57, 122, 142, 257, 292], "idiom": 4, "input_dim": [4, 213, 226, 248], "hidden_dim": [4, 282, 284], "output_dim": [4, 213, 226, 248], "layer_s": 4, "idim": 4, "odim": 4, "maximum": [4, 22, 63, 213, 250, 255, 259, 260, 275, 282, 290], "cross": [4, 261, 263], "entropi": [4, 261, 263], "sub": [4, 147], "commonli": [4, 244], "cross_entropi": [4, 213], "accuraci": 4, "valid": [4, 67, 190, 205, 233, 243, 287], "eval_fn": 4, "argmax": 4, "loader": 4, "num_class": [4, 284], "batch_siz": [4, 284], "num_epoch": [4, 284], "learning_r": [4, 195, 196, 197, 198, 199, 200, 203, 204, 284], "train_imag": [4, 284], "train_label": [4, 284], "test_imag": 4, "test_label": 4, "shuffl": 4, "minibatch": 4, "batch_iter": [4, 284], "perm": 4, "permut": 4, "id": [4, 6], "put": 4, "trainabl": [4, 194, 213, 282], "loss_and_grad_fn": [4, 284, 288], "value_and_grad": [4, 213, 244, 282, 284, 288, 291, 292], "epoch": 4, "test": [4, 6], "confus": 4, "decent": 4, "95": 4, "brought": 5, "research": 5, "except": [5, 81, 88, 89, 91, 92, 93, 223, 235, 289, 291], "featur": [5, 65, 66, 215, 223, 224, 225, 226, 248, 249, 251, 257, 290], "main": [5, 81, 206, 213], "differ": [5, 175, 271, 288], "lazi": [5, 282, 292], "multi": [5, 216, 217, 289, 291], "cpu": [5, 294], "gpu": [5, 289, 294], "inspir": 5, "jax": [5, 285], "arrayfir": 5, "unifi": 5, "live": [5, 294], "guid": 5, "convers": 5, "regress": [5, 268], "layer": [5, 213, 219, 220, 223, 225, 226, 240, 245, 248, 253, 257, 279, 282], "perceptron": 5, "llm": 5, "infer": [5, 97, 112], "fft": 5, "algebra": 5, "tree": [5, 78, 98, 164, 187, 190, 201, 205, 206, 207, 288], "develop": [5, 6], "document": [5, 47, 58, 156, 157, 288, 289], "meet": 6, "seri": 6, "chip": 6, "nativ": 6, "maco": 6, "13": 6, "recommend": [6, 200], "14": 6, "sonoma": 6, "distribut": [6, 140, 141, 142, 144, 148, 149, 226, 266, 270, 272], "probabl": [6, 145, 218, 219, 220, 248, 266, 294], "platform": 6, "processor": 6, "arm": [6, 209], "i386": 6, "switch": 6, "conda": 6, "17": 6, "g": [6, 110, 138, 203, 204, 279, 290, 295], "clang": 6, "cmake": 6, "24": 6, "xcode": 6, "15": [6, 110], "environ": 6, "via": [6, 290, 291], "rosetta": 6, "unam": 6, "p": [6, 140, 197, 199, 213, 218, 219, 220, 272], "clone": 6, "git": 6, "com": 6, "ml": 6, "explor": 6, "cd": 6, "brew": 6, "global": [6, 146, 285], "env": 6, "cmake_build_parallel_level": 6, "edit": [6, 245], "unittest": 6, "discov": 6, "stub": 6, "dev": 6, "generate_stub": 6, "mkdir": 6, "either": [6, 11, 47, 57, 58, 63, 73, 74, 75, 96, 99, 100, 108, 109, 110, 117, 122, 
124, 127, 129, 175, 187, 253], "libmlx": 6, "preprocessor": 6, "metal_path": 6, "mlx_build_test": 6, "ON": 6, "mlx_build_exampl": 6, "mlx_build_benchmark": 6, "mlx_build_python_bind": 6, "multipl": [6, 122, 129, 138, 139, 246, 255, 290, 293], "wish": 6, "variabl": [6, 98, 107, 187, 189, 190], "export": 6, "developer_dir": 6, "app": 6, "content": [6, 232], "sdk": 6, "xcrun": 6, "macosx": 6, "show": [6, 209], "unabl": 6, "tool": 6, "select": [6, 191, 228, 232], "sudo": 6, "ouptut": 6, "finder": 6, "iterm": 6, "termin": 6, "click": 6, "uncheck": 6, "window": 6, "restart": 6, "devicetyp": 7, "attribut": [7, 8, 9, 26, 282], "kwarg": [8, 158, 159, 295], "union": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 144, 145, 147, 148, 149, 150, 151, 152, 153, 154, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 191, 192, 193, 217, 233, 235, 243], "absolut": [10, 13, 259, 260, 271], "semant": [11, 61, 73, 74, 75, 99, 100, 108, 109, 117, 122, 124, 127, 129, 175, 294], "keepdim": [12, 14, 22, 23, 29, 30, 31, 32, 40, 41, 42, 43, 45, 56, 59, 110, 121, 123, 125, 126, 137, 167, 176, 188], "reduct": [12, 14, 121, 123, 126, 137, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272], "reduc": [12, 14, 22, 23, 121, 123, 125, 126, 137, 164, 176, 188, 215, 268], "unspecifi": [12, 14, 15, 22, 23, 24, 25, 64, 97, 121, 123, 125, 126, 132, 136, 137, 151, 167, 168, 176, 178, 188, 192, 295], "entir": [12, 14, 22, 23, 121, 123, 125, 126, 137, 176, 188, 219, 220], "singleton": [12, 14, 22, 23, 121, 122, 123, 125, 126, 137, 176, 188], "rtol": 13, "05": [13, 215, 223, 224, 225, 249], "atol": 13, "08": [13, 196, 197, 198, 199, 203, 262], "approxim": [13, 164, 222, 258, 259, 260], "comparison": [13, 75, 99, 100, 108, 109], "equal": [13, 24, 60, 81, 100, 109, 136, 145, 169, 224, 226], "ab": [13, 110, 187, 223, 224, 225, 227, 249, 273], "array_equ": 13, "rel": 13, "toler": 13, "boolean": [13, 60, 103, 104, 105, 106, 118, 119, 120, 209, 242, 289], "interv": [15, 111, 145, 149], "increment": 15, "otherwis": [15, 205, 206, 233, 235, 243, 256, 257, 265, 271, 278, 290, 291], "int32": [15, 110, 145, 209, 289, 292], "convent": [15, 67, 198], "lead": 15, "fraction": 15, "integr": [15, 178, 290], "invers": [16, 17, 18, 19, 20, 21, 77, 85, 86, 87, 88, 89, 90], "cosin": [16, 17, 68, 69, 251, 262, 288], "hyperbol": [17, 19, 21, 69, 166, 181], "sine": [18, 19, 165, 166, 251, 288], "minimum": [23, 63, 255, 262], "kth": [24, 136], "partit": 24, "order": [24, 110, 136, 138, 213, 223, 244, 253, 288], "undefin": [24, 136, 289], "sort": [24, 25, 136], "flatten": [24, 25, 110, 134, 136, 151, 168, 178, 179, 205], "dimension": [26, 82, 83, 84, 85, 86, 87, 91, 92, 93, 215, 216, 217, 221, 226, 248, 255, 289, 291], "val": [26, 97], "tupl": [26, 47, 58, 64, 66, 74, 78, 80, 107, 110, 135, 138, 152, 172, 187, 189, 197, 198, 199, 200, 205, 206, 207, 217, 235, 237, 251, 253, 287, 288], "ndarrai": [26, 289, 290, 292], "properti": [27, 35, 44, 50, 52, 242, 288], "argument": [27, 47, 58, 78, 98, 187, 206, 
213, 285, 288, 293, 294, 295], "decim": [48, 153], "indices_or_sect": [53, 169], "nest": [57, 213, 282, 287, 288], "ddof": [59, 188], "equal_nan": 60, "nan": [60, 104], "a_min": 63, "a_max": 63, "edg": [63, 135], "At": 63, "anoth": [63, 122, 175, 191, 213, 228, 288, 289, 294], "pad": [65, 66, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 216, 217], "dilat": [65, 66], "group": [65, 66, 72, 138, 139, 223, 248], "1d": [65, 67, 179], "convolut": [65, 66, 67, 216, 217, 219, 220], "channel": [65, 66, 215, 216, 217, 219, 220], "c_in": [65, 66], "c_out": [65, 66], "convolv": [65, 66], "2d": [66, 138, 215, 219], "spatial": [66, 223], "symmetr": 66, "discret": [67, 82, 83, 84, 85, 86, 87, 91, 92, 93, 221], "swap": [67, 177, 245, 248], "conv": 67, "filter": [67, 216, 217, 228, 232], "flip": 67, "signal": 67, "bias": [72, 138, 139, 233, 243, 246], "group_siz": [72, 138, 139, 248], "64": [72, 138, 139, 209, 248], "configur": 72, "formal": [72, 138], "notat": [72, 205, 237], "quantiz": [72, 112, 139, 248], "w_i": [72, 138], "hat": [72, 138], "occupi": [72, 138, 139], "divis": [73, 96, 138], "quotient": [73, 74, 96], "remaind": 74, "fuction": 74, "faster": [74, 258, 288], "mathrm": [76, 162, 224], "frac": [76, 138, 162, 195, 196, 197, 198, 199, 203, 215, 218, 219, 220, 223, 224, 225, 226, 249, 262, 265, 268], "pi": [76, 255, 288], "int_0": 76, "dx": 76, "erf": 77, "node": [78, 164, 190], "dict": [78, 112, 156, 157, 158, 238, 241, 244, 245, 282, 287, 288, 293], "leaf": [78, 205, 206, 232], "exponenti": [79, 252, 276], "insert": [80, 294], "ident": [81, 174, 240], "diagon": [81, 184, 185, 186], "zero": [81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 184, 185, 186, 193, 213, 218, 219, 220, 235, 289], "th": 81, "whose": [81, 194], "One": [82, 85, 91, 154, 288], "fourier": [82, 83, 84, 85, 86, 87, 91, 92, 93], "truncat": [82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 148], "dft": [82, 83, 84, 85, 86, 87, 91, 92, 93], "rfft": 88, "real": [88, 89, 90, 91, 92, 93], "rfft2": 89, "rfftn": 90, "silent": [91, 92, 93], "start_axi": 94, "end_axi": 94, "integ": [96, 110, 135, 138, 139, 140, 145, 169, 182, 190, 209, 221, 289], "floor": 96, "fun": [98, 107, 187, 189, 190, 289, 290, 294], "argnam": [98, 187], "cpp_function": [], "neither": [98, 187], "keyword": [98, 158, 159, 187, 206, 213, 285, 293, 295], "strict": [99, 108, 233, 235, 243], "ordinari": 102, "ord": 110, "tabl": [110, 209, 221], "frobeniu": 110, "matric": 110, "strictli": 110, "mathemat": 110, "variou": 110, "purpos": 110, "calcul": 110, "fro": 110, "inf": [110, 246], "largest": 110, "sing": 110, "smallest": 110, "singular": 110, "nuclear": 110, "_f": 110, "sum_": [110, 268], "a_": 110, "valueerror": [110, 235, 288], "refer": [110, 224, 227, 273, 289], "golub": 110, "van": 110, "loan": 110, "baltimor": 110, "md": 110, "john": 110, "hopkin": 110, "univers": 110, "1985": 110, "pg": 110, "la": 110, "arang": [110, 289, 291], "9": [110, 195, 197, 198, 199, 200, 291], "74597": 110, "20": 110, "84804": 110, "41421": 110, "23607": 110, "74166": 110, "24264": 110, "11": 110, "225": 110, "50": 111, "evenli": 111, "binari": [112, 155, 156, 157, 158, 159, 256, 261, 278], "npy": [112, 155, 293], "safetensor": [112, 157, 290, 293], "gguf": [112, 156, 293], "unsupport": 112, "tensor": [112, 182, 272, 291], "natur": [113, 115, 290], "logarithm": [113, 114, 115, 116], "log": [115, 117, 121, 266, 268, 270], "plu": 115, "exp": [117, 121, 142, 167, 252, 266, 276, 294], "stabl": [117, 121, 167, 268], "prepend": 122, "remov": [122, 141, 172], "negat": 130, 
"beforehand": 134, "pad_with": 135, "constant_valu": 135, "pad_width": 135, "before_1": 135, "after_1": 135, "before_2": 135, "after_2": 135, "before_n": 135, "after_n": 135, "before_i": 135, "after_i": 135, "extend": 135, "side": 135, "smaller": [136, 200], "everi": [138, 164, 206, 288], "particular": [138, 223], "consecut": [138, 251], "w_1": 138, "w_g": 138, "begin": [138, 252, 256, 265, 271, 276, 278], "align": 138, "max_i": 138, "min_i": 138, "textrm": [138, 222, 258], "round": 138, "pack": [138, 139], "unsign": [138, 139, 209], "lower": [138, 145, 148, 149, 184], "upper": [138, 145, 148, 149], "1st": 138, "signific": 138, "2nd": 138, "dequant": 138, "w_q": 138, "whether": [139, 232, 246], "prng": [140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 285], "num_sampl": 141, "unnorm": [141, 261, 263], "draw": 141, "uint32": [141, 209], "cdf": [142, 222, 258], "accord": [142, 191, 246], "seed": 143, "low": [145, 149], "high": [145, 149, 213, 221], "bound": [145, 148, 149, 222, 289, 294], "roadcast": 145, "domain": 148, "uniformli": 149, "repetit": 151, "preserv": [152, 288], "reciproc": 154, "arr": [155, 289], "uncompress": 158, "my_path": 158, "tree_flatten": [158, 206, 207, 213], "transformerencod": 158, "128": [158, 213], "flat_param": 158, "compress": 159, "simplif": 164, "reus": 164, "consumpt": 164, "meant": 164, "overhead": [164, 290, 294], "1m": 164, "thousand": [164, 290], "foo": 164, "matmul": [164, 294], "twice": [164, 294], "subarrai": 169, "being": [174, 213], "prevent": [174, 272, 291], "flow": [174, 290], "unchang": [174, 251], "axis1": 177, "axis2": 177, "taken": 178, "prior": [178, 179], "exclud": 179, "dot": [182, 205, 237, 246], "elsewher": [184, 289], "col": 184, "triangl": 184, "mse": 187, "param": [187, 213, 288], "lvalu": 187, "dlvalu": 187, "dparam": 187, "lasso": 187, "l1": [187, 265, 267, 268, 271], "varianc": [188, 215, 223], "divisor": 188, "cotang": 189, "in_ax": [190, 288], "out_ax": [190, 288], "prefix": [190, 205], "fn": [194, 206, 292], "callabl": [194, 205, 206, 228, 229, 232, 253, 257], "wrt": 194, "rho": 195, "06": [195, 272], "paper": [195, 196, 197, 199, 200, 215, 255], "zeiler": 195, "2012": [195, 203], "adapt": [195, 196], "1212": 195, "5701": 195, "v_": [195, 196, 197, 198, 199, 203, 204], "v_t": [195, 196, 197, 198, 199, 203, 204], "g_t": [195, 196, 197, 198, 199, 200, 203, 204], "delta": [195, 265], "w_": [195, 196, 197, 198, 199, 200, 203, 204], "u_t": 195, "epsilon": [195, 196, 197, 198, 199, 203, 215, 223, 224, 225, 249, 262], "u_": 195, "w_t": [195, 196, 197, 198, 199, 200, 203, 204], "lambda": [195, 196, 197, 198, 199, 200, 203, 204, 206, 213, 228, 233, 252, 276, 288], "averag": [195, 197, 198, 199], "denomin": [195, 196, 197, 198, 199, 203, 224, 262], "stabil": [195, 196, 197, 198, 199, 203, 215, 223, 224, 225, 249, 262], "duchi": 196, "hazan": 196, "singer": 196, "2011": 196, "subgradi": 196, "onlin": 196, "stochast": [196, 197, 199, 204, 290], "jmlr": 196, "999": [197, 198, 199], "omit": [197, 199], "estim": [197, 199], "kingma": [197, 199], "ba": [197, 199], "2015": [197, 199, 219], "iclr": [197, 198, 199], "m_": [197, 198, 199, 200], "beta_1": [197, 198, 199, 200], "m_t": [197, 198, 199, 200], "beta_2": [197, 198, 199, 200], "weight_decai": [198, 200, 204], "contrast": [198, 202], "loshchilov": 198, "hutter": 198, "decoupl": 198, "decai": [198, 200, 204], "regular": [198, 219, 227, 273, 289], "adam": [199, 200], "infin": [103, 105, 106, 199], "99": [200, 203], "sign": [200, 209], "tend": 200, "larger": [200, 251], "10x": 200, "adamw": 
200, "maintain": [200, 219, 220], "strength": [200, 204], "wd": 200, "chen": 200, "symbol": 200, "discoveri": 200, "2302": 200, "06675": 200, "c_": 200, "eta": 200, "c_t": 200, "momentum": [200, 204, 215], "basi": 201, "appli": [201, 206, 213, 215, 216, 217, 219, 220, 222, 223, 224, 225, 226, 227, 229, 240, 247, 248, 249, 250, 252, 254, 256, 258, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278], "optimizerst": 201, "recurs": [202, 213, 232, 233, 238, 241, 243, 282], "defaultdict": 202, "miss": [202, 235, 293], "present": 202, "tieleman": 203, "hinton": 203, "lectur": 203, "coursera": 203, "smooth": [203, 263, 271], "dampen": 204, "nesterov": 204, "descent": [204, 290], "mu": 204, "tau": 204, "l2": [204, 265, 268], "penalti": 204, "is_leaf": [205, 206], "arbitrari": [205, 282], "depth": [205, 220, 288], "hello": [205, 207], "charact": 205, "flat": [205, 207], "superset": 206, "extra": 206, "closer": 206, "constitut": 206, "dict_kei": 206, "recreat": 207, "world": 207, "42": 207, "byte": 209, "bool_": 209, "uint8": 209, "uint16": 209, "16": [209, 224, 228, 282], "uint64": 209, "int8": 209, "int16": 209, "int64": 209, "arbitrarili": [213, 287, 288, 292], "done": [213, 218, 290, 291], "manual": 213, "explicitli": [213, 285], "solv": 213, "intuit": 213, "freez": [213, 243, 282], "finetun": 213, "in_dim": [213, 282], "out_dim": [213, 282], "enumer": 213, "caus": [213, 290], "local": [213, 219], "scope": 213, "l2_loss": 213, "y_hat": 213, "trainable_paramet": [213, 232], "loss_and_grad": 213, "workhors": 213, "Its": 213, "frozen": [213, 233, 241, 243, 248, 282], "individu": [213, 219, 220], "subset": [213, 232], "action": 213, "displai": 213, "tree_map": 213, "count": 213, "num_param": 213, "preclud": 213, "pure": [213, 284], "pattern": [213, 290], "achiev": 213, "other_input": 213, "necessari": 213, "wrap": 213, "apply_to_modul": [213, 233], "children": 213, "filter_and_map": 213, "leaf_modul": 213, "load_weight": [213, 290], "named_modul": 213, "save_weight": 213, "unfreez": [213, 233], "update_modul": 213, "sequenti": 213, "relu": [213, 247, 257, 274], "prelu": 213, "gelu": [213, 259, 260], "silu": 213, "selu": 213, "mish": 213, "quantizedlinear": 213, "conv1d": 213, "conv2d": 213, "batchnorm": 213, "layernorm": 213, "groupnorm": 213, "instancenorm": 213, "dropout": [213, 219, 220, 240, 257], "dropout2d": 213, "dropout3d": 213, "alibi": 213, "sinusoidalpositionalencod": 213, "gelu_approx": [213, 222, 258], "gelu_fast_approx": [213, 222, 258], "binary_cross_entropi": 213, "kl_div_loss": 213, "l1_loss": 213, "mse_loss": 213, "nll_loss": 213, "smooth_l1_loss": 213, "triplet_loss": 213, "hinge_loss": 213, "huber_loss": 213, "log_cosh_loss": 213, "cosine_similarity_loss": 213, "affin": [215, 223, 224, 225, 226, 248], "track_running_stat": 215, "var": [215, 223, 224, 225], "gamma": [215, 223, 224, 225, 249], "nc": 215, "nlc": [215, 216], "four": 215, "nhwc": [215, 217], "height": [215, 217, 219, 220], "width": [215, 217, 219, 220, 248], "deep": 215, "intern": 215, "covari": 215, "shift": 215, "bn": 215, "in_channel": [216, 217], "out_channel": [216, 217], "kernel_s": [216, 217], "learnabl": [216, 217, 253], "portion": 218, "dure": [218, 219, 220, 291], "independ": [219, 220], "nwhc": 219, "whc": 219, "entri": [219, 220], "benefici": [219, 220, 290], "earli": 219, "adjac": 219, "pixel": 219, "correl": 219, "thompson": 219, "goroshin": 219, "jain": 219, "lecun": 219, "bregler": 219, "cvpr": 219, "ndhwc": 220, "dhwc": 220, "medic": 220, "video": 220, 
"num_embed": 221, "lookup": 221, "typic": [221, 284, 290], "usual": [221, 287, 290], "vocabulari": 221, "approx": 222, "unit": [222, 250, 252, 254, 258, 259, 260, 275, 276, 277], "phi": [222, 258], "geluapprox": 222, "sigma": [222, 254, 259, 260, 277], "60033": [222, 259], "0433603": [222, 259], "gelufast": 222, "773": [222, 260], "regard": 222, "num_group": 223, "pytorch_compat": 223, "split": 223, "preced": 223, "http": [223, 224, 225, 227, 249, 273], "org": [223, 224, 225, 227, 249, 273], "1803": 223, "08494": 223, "inorm": 224, "1607": [224, 225], "08022": 224, "06450": 225, "uniform": [226, 235, 285, 288, 294], "mathcal": 226, "u": 226, "d_i": 226, "monoton": [227, 273], "1908": [227, 273], "08681": [227, 273], "tanh": [227, 273], "softplu": [227, 273], "map_fn": [228, 232], "filter_fn": [228, 232], "valid_parameter_filt": 228, "apply_fn": 229, "descend": 230, "is_leaf_fn": 232, "found": 232, "drop": 232, "idempot": [233, 243], "attent": [233, 246, 255, 257], "endswith": 233, "file_or_weight": 235, "ok": [235, 288], "certain": 240, "ie": 243, "noop": 243, "unfrozen": 243, "chang": [244, 248, 265, 271, 291], "tracer": 244, "partial": [244, 245, 290], "child": 245, "programmat": 245, "query_input_dim": 246, "key_input_dim": 246, "value_input_dim": 246, "value_dim": 246, "value_output_dim": 246, "head": [246, 257], "aggreg": 246, "linearli": 246, "neg": [105, 246, 270, 272, 289], "attend": 246, "num_paramet": 247, "init": 247, "25": 247, "parametr": [247, 274], "classmethod": 248, "from_linear": 248, "quantize_modul": 248, "1910": 249, "07467": 249, "rectifi": [250, 275], "10000": 251, "rotat": 251, "slightli": [251, 294], "angular": 251, "frequenc": [251, 255], "_cos_sin_theta_kei": 251, "precomput": 251, "_cos_sin_theta_valu": 251, "leq": [252, 265, 276], "0507": [252, 276], "67326": [252, 276], "elu": [252, 276], "plain": 253, "known": [254, 277], "swish": [254, 277], "cdot": [254, 259, 260, 262, 264, 277], "min_freq": 255, "0001": 255, "max_freq": 255, "cos_first": 255, "full_turn": 255, "sinusoid": 255, "sin": [255, 288, 292], "threshold": [256, 265, 271, 278], "geq": [256, 278], "num_encoder_lay": 257, "num_decoder_lay": 257, "custom_encod": 257, "custom_decod": 257, "norm_first": 257, "decod": 257, "interact": 257, "mechan": 257, "hidden": 257, "exact": [259, 260], "0003": 259, "015": 260, "pre": 261, "predict": [261, 263, 264, 265, 266, 267, 268, 269, 270, 271], "105361": 261, "223144": 261, "20397": 261, "916291": 261, "612192": 261, "x1": 262, "x2": 262, "x_1": 262, "x_2": 262, "label_smooth": 263, "hing": 264, "y_": [264, 268], "pred": [264, 268], "huber": 265, "l_": [265, 272], "kullback": 266, "leibler": 266, "diverg": 266, "cosh": 268, "logcosh": 268, "sensit": 268, "outlier": 268, "dual": 268, "behavior": [268, 289, 290], "offer": 268, "balanc": 268, "robust": 268, "approach": [268, 288], "task": 268, "likelihood": 270, "nll": 270, "formula": 271, "anchor": 272, "margin": 272, "triplet": 272, "_p": 272, "degre": 272, "pairwis": 272, "instabl": 272, "subclass": 282, "concept": 282, "mymlp": 282, "in_proj": 282, "subsequ": 284, "implicit": [285, 288], "fine": [285, 290], "grain": 285, "control": [285, 290], "manag": [285, 294], "pseudo": 285, "altern": 285, "splittabl": 285, "threefri": 285, "counter": 285, "cycl": 287, "slice": 289, "ellipsi": 289, "syntax": 289, "idx": 289, "mix": 289, "take_along_axi": 289, "lack": 289, "propag": [288, 289], "extrem": [289, 290], "ineffici": [289, 290], "nonzero": 289, "reflect": [289, 291], "dfdx": [288, 289], "record": 290, "nice": 
[288, 290], "rerun": 290, "dynam": 290, "easier": 290, "worri": 290, "fun1": 290, "expensive_fun": 290, "cost": 290, "code": 290, "consum": 290, "eager": 290, "thank": 290, "weights_fp16": 290, "trade": 290, "too": 290, "bad": 290, "idea": [288, 290], "On": [288, 290], "grow": 290, "computation": 290, "costli": 290, "wide": 290, "pretti": 290, "ten": [288, 290], "okai": 290, "outer": 290, "value_and_grad_fn": 290, "awar": 290, "implicitli": 290, "anytim": 290, "memoryview": [290, 291], "perfectli": 290, "first_lay": 290, "second_layer_a": 290, "second_layer_b": 290, "frequent": 290, "protocol": 291, "receiv": 291, "pep": 291, "3118": 291, "view": 291, "a_view": 291, "owndata": 291, "quit": [288, 291], "power": [288, 291], "extern": 291, "x_view": 291, "modifi": 291, "df": 291, "x\u00b2": 291, "2x": 291, "indirectli": 291, "modif": 291, "seen": 291, "occur": 291, "incorpor": 291, "issu": [288, 291], "incorrect": 291, "experiment": 291, "break": 291, "advis": 291, "intermedi": 291, "jnp": 291, "tf": 291, "inspect": 292, "page": 292, "composit": 292, "archiv": 293, "savez_compress": 293, "save_safetensor": 293, "save_gguf": 293, "arr_0": 293, "pool": 294, "advantag": 294, "don": 294, "parallel": 294, "race": 294, "interest": 294, "albeit": 294, "contriv": [288, 294], "suppos": [288, 294], "d1": 294, "d2": 294, "4096": [288, 294], "dens": 294, "better": [288, 294], "millisecond": 294, "measur": 294, "default_stream": 295, "default_devic": 295, "my_devic": 295, "pypi": 6, "forg": 6, "grep": 6, "cmake_host_system_processor": 6, "arm64": 6, "x86_64": 6, "wipe": 6, "cahc": 6, "rf": 6, "inifn": 103, "behind": 288, "d2fdx2": 288, "differentiaion": 288, "backward": 288, "zero_grad": 288, "detach": 288, "requires_grad": 288, "dloss_dw": 288, "dloss_dx": 288, "lot": 288, "redund": 288, "stop_gradi": 288, "autom": 288, "sake": 288, "clariti": 288, "difficult": 288, "primit": 288, "priorit": 288, "xs": 288, "ys": 288, "naive_add": 288, "vmap_add": 288, "timeit": 288, "total": 288, "390": 288, "wherea": 288, "025": 288, "Of": 288, "handi": 288}, "objects": {"mlx.core": [[7, 0, 1, "", "Device"], [8, 0, 1, "", "Dtype"], [9, 0, 1, "", "Stream"], [10, 2, 1, "", "abs"], [11, 2, 1, "", "add"], [12, 2, 1, "", "all"], [13, 2, 1, "", "allclose"], [14, 2, 1, "", "any"], [15, 2, 1, "", "arange"], [16, 2, 1, "", "arccos"], [17, 2, 1, "", "arccosh"], [18, 2, 1, "", "arcsin"], [19, 2, 1, "", "arcsinh"], [20, 2, 1, "", "arctan"], [21, 2, 1, "", "arctanh"], [22, 2, 1, "", "argmax"], [23, 2, 1, "", "argmin"], [24, 2, 1, "", "argpartition"], [25, 2, 1, "", "argsort"], [26, 0, 1, "", "array"], [60, 2, 1, "", "array_equal"], [61, 2, 1, "", "broadcast_to"], [62, 2, 1, "", "ceil"], [63, 2, 1, "", "clip"], [64, 2, 1, "", "concatenate"], [65, 2, 1, "", "conv1d"], [66, 2, 1, "", "conv2d"], [67, 2, 1, "", "convolve"], [68, 2, 1, "", "cos"], [69, 2, 1, "", "cosh"], [70, 2, 1, "", "default_device"], [71, 2, 1, "", "default_stream"], [72, 2, 1, "", "dequantize"], [73, 2, 1, "", "divide"], [74, 2, 1, "", "divmod"], [75, 2, 1, "", "equal"], [76, 2, 1, "", "erf"], [77, 2, 1, "", "erfinv"], [78, 2, 1, "", "eval"], [79, 2, 1, "", "exp"], [80, 2, 1, "", "expand_dims"], [81, 2, 1, "", "eye"], [94, 2, 1, "", "flatten"], [95, 2, 1, "", "floor"], [96, 2, 1, "", "floor_divide"], [97, 2, 1, "", "full"], [98, 2, 1, "", "grad"], [99, 2, 1, "", "greater"], [100, 2, 1, "", "greater_equal"], [101, 2, 1, "", "identity"], [102, 2, 1, "", "inner"], [103, 2, 1, "", "isinf"], [104, 2, 1, "", "isnan"], [105, 2, 1, "", "isneginf"], [106, 2, 1, "", 
"isposinf"], [107, 2, 1, "", "jvp"], [108, 2, 1, "", "less"], [109, 2, 1, "", "less_equal"], [111, 2, 1, "", "linspace"], [112, 2, 1, "", "load"], [113, 2, 1, "", "log"], [114, 2, 1, "", "log10"], [115, 2, 1, "", "log1p"], [116, 2, 1, "", "log2"], [117, 2, 1, "", "logaddexp"], [118, 2, 1, "", "logical_and"], [119, 2, 1, "", "logical_not"], [120, 2, 1, "", "logical_or"], [121, 2, 1, "", "logsumexp"], [122, 2, 1, "", "matmul"], [123, 2, 1, "", "max"], [124, 2, 1, "", "maximum"], [125, 2, 1, "", "mean"], [126, 2, 1, "", "min"], [127, 2, 1, "", "minimum"], [128, 2, 1, "", "moveaxis"], [129, 2, 1, "", "multiply"], [130, 2, 1, "", "negative"], [131, 2, 1, "", "new_stream"], [132, 2, 1, "", "ones"], [133, 2, 1, "", "ones_like"], [134, 2, 1, "", "outer"], [135, 2, 1, "", "pad"], [136, 2, 1, "", "partition"], [137, 2, 1, "", "prod"], [138, 2, 1, "", "quantize"], [139, 2, 1, "", "quantized_matmul"], [150, 2, 1, "", "reciprocal"], [151, 2, 1, "", "repeat"], [152, 2, 1, "", "reshape"], [153, 2, 1, "", "round"], [154, 2, 1, "", "rsqrt"], [155, 2, 1, "", "save"], [156, 2, 1, "", "save_gguf"], [157, 2, 1, "", "save_safetensors"], [158, 2, 1, "", "savez"], [159, 2, 1, "", "savez_compressed"], [160, 2, 1, "", "set_default_device"], [161, 2, 1, "", "set_default_stream"], [162, 2, 1, "", "sigmoid"], [163, 2, 1, "", "sign"], [164, 2, 1, "", "simplify"], [165, 2, 1, "", "sin"], [166, 2, 1, "", "sinh"], [167, 2, 1, "", "softmax"], [168, 2, 1, "", "sort"], [169, 2, 1, "", "split"], [170, 2, 1, "", "sqrt"], [171, 2, 1, "", "square"], [172, 2, 1, "", "squeeze"], [173, 2, 1, "", "stack"], [174, 2, 1, "", "stop_gradient"], [175, 2, 1, "", "subtract"], [176, 2, 1, "", "sum"], [177, 2, 1, "", "swapaxes"], [178, 2, 1, "", "take"], [179, 2, 1, "", "take_along_axis"], [180, 2, 1, "", "tan"], [181, 2, 1, "", "tanh"], [182, 2, 1, "", "tensordot"], [183, 2, 1, "", "transpose"], [184, 2, 1, "", "tri"], [185, 2, 1, "", "tril"], [186, 2, 1, "", "triu"], [187, 2, 1, "", "value_and_grad"], [188, 2, 1, "", "var"], [189, 2, 1, "", "vjp"], [190, 2, 1, "", "vmap"], [191, 2, 1, "", "where"], [192, 2, 1, "", "zeros"], [193, 2, 1, "", "zeros_like"]], "mlx.core.Device": [[7, 1, 1, "", "__init__"]], "mlx.core.Dtype": [[8, 1, 1, "", "__init__"]], "mlx.core.Stream": [[9, 1, 1, "", "__init__"]], "mlx.core.array": [[27, 3, 1, "", "T"], [26, 1, 1, "", "__init__"], [28, 1, 1, "", "abs"], [29, 1, 1, "", "all"], [30, 1, 1, "", "any"], [31, 1, 1, "", "argmax"], [32, 1, 1, "", "argmin"], [33, 1, 1, "", "astype"], [34, 1, 1, "", "cos"], [35, 3, 1, "", "dtype"], [36, 1, 1, "", "exp"], [37, 1, 1, "", "item"], [38, 1, 1, "", "log"], [39, 1, 1, "", "log1p"], [40, 1, 1, "", "logsumexp"], [41, 1, 1, "", "max"], [42, 1, 1, "", "mean"], [43, 1, 1, "", "min"], [44, 3, 1, "", "ndim"], [45, 1, 1, "", "prod"], [46, 1, 1, "", "reciprocal"], [47, 1, 1, "", "reshape"], [48, 1, 1, "", "round"], [49, 1, 1, "", "rsqrt"], [50, 3, 1, "", "shape"], [51, 1, 1, "", "sin"], [52, 3, 1, "", "size"], [53, 1, 1, "", "split"], [54, 1, 1, "", "sqrt"], [55, 1, 1, "", "square"], [56, 1, 1, "", "sum"], [57, 1, 1, "", "tolist"], [58, 1, 1, "", "transpose"], [59, 1, 1, "", "var"]], "mlx.core.fft": [[82, 2, 1, "", "fft"], [83, 2, 1, "", "fft2"], [84, 2, 1, "", "fftn"], [85, 2, 1, "", "ifft"], [86, 2, 1, "", "ifft2"], [87, 2, 1, "", "ifftn"], [88, 2, 1, "", "irfft"], [89, 2, 1, "", "irfft2"], [90, 2, 1, "", "irfftn"], [91, 2, 1, "", "rfft"], [92, 2, 1, "", "rfft2"], [93, 2, 1, "", "rfftn"]], "mlx.core.linalg": [[110, 2, 1, "", "norm"]], "mlx.core.random": [[140, 2, 1, "", 
"bernoulli"], [141, 2, 1, "", "categorical"], [142, 2, 1, "", "gumbel"], [143, 2, 1, "", "key"], [144, 2, 1, "", "normal"], [145, 2, 1, "", "randint"], [146, 2, 1, "", "seed"], [147, 2, 1, "", "split"], [148, 2, 1, "", "truncated_normal"], [149, 2, 1, "", "uniform"]], "mlx.nn": [[214, 0, 1, "", "ALiBi"], [215, 0, 1, "", "BatchNorm"], [216, 0, 1, "", "Conv1d"], [217, 0, 1, "", "Conv2d"], [218, 0, 1, "", "Dropout"], [219, 0, 1, "", "Dropout2d"], [220, 0, 1, "", "Dropout3d"], [221, 0, 1, "", "Embedding"], [222, 0, 1, "", "GELU"], [223, 0, 1, "", "GroupNorm"], [224, 0, 1, "", "InstanceNorm"], [225, 0, 1, "", "LayerNorm"], [226, 0, 1, "", "Linear"], [227, 0, 1, "", "Mish"], [282, 0, 1, "", "Module"], [246, 0, 1, "", "MultiHeadAttention"], [247, 0, 1, "", "PReLU"], [248, 0, 1, "", "QuantizedLinear"], [249, 0, 1, "", "RMSNorm"], [250, 0, 1, "", "ReLU"], [251, 0, 1, "", "RoPE"], [252, 0, 1, "", "SELU"], [253, 0, 1, "", "Sequential"], [254, 0, 1, "", "SiLU"], [255, 0, 1, "", "SinusoidalPositionalEncoding"], [256, 0, 1, "", "Step"], [257, 0, 1, "", "Transformer"], [258, 0, 1, "", "gelu"], [259, 0, 1, "", "gelu_approx"], [260, 0, 1, "", "gelu_fast_approx"], [273, 0, 1, "", "mish"], [274, 0, 1, "", "prelu"], [275, 0, 1, "", "relu"], [276, 0, 1, "", "selu"], [277, 0, 1, "", "silu"], [278, 0, 1, "", "step"], [194, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[228, 1, 1, "", "apply"], [229, 1, 1, "", "apply_to_modules"], [230, 1, 1, "", "children"], [231, 1, 1, "", "eval"], [232, 1, 1, "", "filter_and_map"], [233, 1, 1, "", "freeze"], [234, 1, 1, "", "leaf_modules"], [235, 1, 1, "", "load_weights"], [236, 1, 1, "", "modules"], [237, 1, 1, "", "named_modules"], [238, 1, 1, "", "parameters"], [239, 1, 1, "", "save_weights"], [240, 1, 1, "", "train"], [241, 1, 1, "", "trainable_parameters"], [242, 3, 1, "", "training"], [243, 1, 1, "", "unfreeze"], [244, 1, 1, "", "update"], [245, 1, 1, "", "update_modules"]], "mlx.nn.RoPE": [[251, 4, 1, "", "_cos_sin_theta_key"], [251, 4, 1, "", "_cos_sin_theta_value"]], "mlx.nn.losses": [[261, 0, 1, "", "binary_cross_entropy"], [262, 0, 1, "", "cosine_similarity_loss"], [263, 0, 1, "", "cross_entropy"], [264, 0, 1, "", "hinge_loss"], [265, 0, 1, "", "huber_loss"], [266, 0, 1, "", "kl_div_loss"], [267, 0, 1, "", "l1_loss"], [268, 0, 1, "", "log_cosh_loss"], [269, 0, 1, "", "mse_loss"], [270, 0, 1, "", "nll_loss"], [271, 0, 1, "", "smooth_l1_loss"], [272, 0, 1, "", "triplet_loss"]], "mlx.optimizers": [[195, 0, 1, "", "AdaDelta"], [196, 0, 1, "", "Adagrad"], [197, 0, 1, "", "Adam"], [198, 0, 1, "", "AdamW"], [199, 0, 1, "", "Adamax"], [200, 0, 1, "", "Lion"], [201, 0, 1, "", "Optimizer"], [202, 0, 1, "", "OptimizerState"], [203, 0, 1, "", "RMSprop"], [204, 0, 1, "", "SGD"]], "mlx.optimizers.Optimizer": [[201, 4, 1, "", "state"]], "mlx.utils": [[205, 2, 1, "", "tree_flatten"], [206, 2, 1, "", "tree_map"], [207, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property", "4": "py:attribute"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"], "4": ["py", "attribute", "Python attribute"]}, "titleterms": {"oper": [0, 1, 283], "develop": 1, "document": 1, "introduc": 1, "exampl": [1, 5, 294], "primit": 1, "us": [1, 290, 295], "implement": [1, 3], "cpu": 1, "backend": 1, "gpu": 1, "transform": [1, 257, 286, 288, 290, 292], "build": [1, 6], "bind": 1, "python": [1, 5, 6], "cmake": 1, 
"setuptool": 1, "usag": [1, 5], "result": 1, "script": [1, 3], "download": [1, 3], "code": [1, 3], "linear": [2, 212, 226], "regress": 2, "llm": 3, "infer": 3, "model": 3, "attent": 3, "layer": [3, 4, 280], "encod": 3, "full": [3, 97], "gener": 3, "put": 3, "all": [3, 12, 29], "togeth": 3, "convert": 3, "weight": 3, "load": [3, 112, 293], "benchmark": 3, "multi": 4, "perceptron": 4, "mlx": [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278], "instal": [5, 6], "api": [5, 6], "refer": 5, "c": [5, 6], "further": 5, "read": 5, "from": [6, 289], "pypi": [], "troubleshoot": 6, "sourc": 6, "requir": 6, "option": 6, "metal": 6, "found": 6, "x86": 6, "shell": 6, "core": [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193], "devic": [7, 210], "dtype": [8, 35], "stream": [9, 210, 295], "ab": [10, 28], "add": 11, "allclos": 13, "ani": [14, 30], "arang": 15, "arcco": 16, "arccosh": 17, "arcsin": 18, "arcsinh": 19, "arctan": 20, "arctanh": 21, "argmax": [22, 31], "argmin": [23, 32], "argpartit": 24, "argsort": 25, "arrai": [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 208, 289, 293], "t": 27, "astyp": 33, "co": [34, 68], "exp": [36, 79], "item": 37, "log": [38, 113], "log1p": [39, 115], "logsumexp": [40, 121], "max": [41, 123], "mean": [42, 125], "min": [43, 126], "ndim": 44, "prod": [45, 137], "reciproc": [46, 150], "reshap": [47, 152], "round": [48, 153], "rsqrt": [49, 154], "shape": 50, "sin": [51, 165], "size": 52, "split": [53, 147, 169], "sqrt": [54, 170], "squar": [55, 
171], "sum": [56, 176], "tolist": 57, "transpos": [58, 183], "var": [59, 188], "array_equ": 60, "broadcast_to": 61, "ceil": 62, "clip": 63, "concaten": 64, "conv1d": [65, 216], "conv2d": [66, 217], "convolv": 67, "cosh": 69, "default_devic": 70, "default_stream": 71, "dequant": 72, "divid": 73, "divmod": 74, "equal": 75, "erf": 76, "erfinv": 77, "eval": [78, 231], "expand_dim": 80, "ey": 81, "fft": [82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 211], "fft2": 83, "fftn": 84, "ifft": 85, "ifft2": 86, "ifftn": 87, "irfft": 88, "irfft2": 89, "irfftn": 90, "rfft": 91, "rfft2": 92, "rfftn": 93, "flatten": 94, "floor": 95, "floor_divid": 96, "grad": [98, 213], "greater": 99, "greater_equ": 100, "ident": 101, "inner": 102, "jvp": 107, "less": 108, "less_equ": 109, "linalg": 110, "norm": 110, "linspac": 111, "log10": 114, "log2": 116, "logaddexp": 117, "logical_and": 118, "logical_not": 119, "logical_or": 120, "matmul": 122, "maximum": 124, "minimum": 127, "moveaxi": 128, "multipli": 129, "neg": 130, "new_stream": 131, "ones": 132, "ones_lik": 133, "outer": 134, "pad": 135, "partit": 136, "quantiz": 138, "quantized_matmul": 139, "random": [140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 285], "bernoulli": 140, "categor": 141, "gumbel": 142, "kei": 143, "normal": 144, "randint": 145, "seed": 146, "truncated_norm": 148, "uniform": 149, "repeat": 151, "save": [155, 293], "save_gguf": 156, "save_safetensor": 157, "savez": 158, "savez_compress": 159, "set_default_devic": 160, "set_default_stream": 161, "sigmoid": 162, "sign": 163, "simplifi": 164, "sinh": 166, "softmax": 167, "sort": 168, "squeez": 172, "stack": 173, "stop_gradi": 174, "subtract": 175, "swapax": 177, "take": 178, "take_along_axi": 179, "tan": 180, "tanh": 181, "tensordot": 182, "tri": 184, "tril": 185, "triu": 186, "value_and_grad": [187, 194], "vjp": 189, "vmap": 190, "where": 191, "zero": 192, "zeros_lik": 193, "nn": [194, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278], "optim": [195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 284], "adadelta": 195, "adagrad": 196, "adam": 197, "adamw": 198, "adamax": 199, "lion": 200, "optimizerst": 202, "rmsprop": 203, "sgd": 204, "util": [205, 206, 207, 287], "tree_flatten": 205, "tree_map": 206, "tree_unflatten": 207, "data": 209, "type": 209, "support": 209, "algebra": 212, "neural": 213, "network": 213, "quick": [213, 292], "start": [213, 292], "The": 213, "modul": [213, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 282], "class": 213, "paramet": [213, 238], "updat": [213, 244, 289], "inspect": 213, "valu": 213, "alibi": 214, "batchnorm": 215, "dropout": 218, "dropout2d": 219, "dropout3d": 220, "embed": 221, "gelu": [222, 258], "groupnorm": 223, "instancenorm": 224, "layernorm": 225, "mish": [227, 273], "appli": 228, "apply_to_modul": 229, "children": 230, "filter_and_map": 232, "freez": 233, "leaf_modul": 234, "load_weight": 235, "named_modul": 237, "save_weight": 239, "train": [240, 242], "trainable_paramet": 241, "unfreez": 243, "update_modul": 245, "multiheadattent": 246, "prelu": [247, 274], "quantizedlinear": 248, "rmsnorm": 249, "relu": [250, 275], "rope": 251, "selu": [252, 276], "sequenti": 253, "silu": [254, 277], "sinusoidalpositionalencod": 255, "step": [256, 
278], "gelu_approx": 259, "gelu_fast_approx": 260, "loss": [261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 281], "binary_cross_entropi": 261, "cosine_similarity_loss": 262, "cross_entropi": 263, "hinge_loss": 264, "huber_loss": 265, "kl_div_loss": 266, "l1_loss": 267, "log_cosh_loss": 268, "mse_loss": 269, "nll_loss": 270, "smooth_l1_loss": 271, "triplet_loss": 272, "function": [279, 281, 288, 292], "tree": 287, "index": 289, "differ": 289, "numpi": [289, 291], "In": 289, "place": 289, "lazi": 290, "evalu": 290, "why": 290, "comput": 290, "graph": [290, 292], "onli": 290, "what": 290, "you": 290, "when": 290, "convers": 291, "other": 291, "framework": 291, "pytorch": 291, "jax": 291, "tensorflow": 291, "guid": 292, "basic": 292, "serial": 293, "format": 293, "unifi": 294, "memori": 294, "A": 294, "simpl": 294, "specifi": 295, "isinf": 103, "isnan": 104, "isneginf": 105, "isposinf": 106, "automat": 288, "differenti": 288, "vector": 288}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}})
\ No newline at end of file
+Search.setIndex({"docnames": ["cpp/ops", "dev/extensions", "examples/linear_regression", "examples/llama-inference", "examples/mlp", "index", "install", "python/_autosummary/mlx.core.Device", "python/_autosummary/mlx.core.Dtype", "python/_autosummary/mlx.core.Stream", "python/_autosummary/mlx.core.abs", "python/_autosummary/mlx.core.add", "python/_autosummary/mlx.core.all", "python/_autosummary/mlx.core.allclose", "python/_autosummary/mlx.core.any", "python/_autosummary/mlx.core.arange", "python/_autosummary/mlx.core.arccos", "python/_autosummary/mlx.core.arccosh", "python/_autosummary/mlx.core.arcsin", "python/_autosummary/mlx.core.arcsinh", "python/_autosummary/mlx.core.arctan", "python/_autosummary/mlx.core.arctanh", "python/_autosummary/mlx.core.argmax", "python/_autosummary/mlx.core.argmin", "python/_autosummary/mlx.core.argpartition", "python/_autosummary/mlx.core.argsort", "python/_autosummary/mlx.core.array", "python/_autosummary/mlx.core.array.T", "python/_autosummary/mlx.core.array.abs", "python/_autosummary/mlx.core.array.all", "python/_autosummary/mlx.core.array.any", "python/_autosummary/mlx.core.array.argmax", "python/_autosummary/mlx.core.array.argmin", "python/_autosummary/mlx.core.array.astype", "python/_autosummary/mlx.core.array.cos", "python/_autosummary/mlx.core.array.dtype", "python/_autosummary/mlx.core.array.exp", "python/_autosummary/mlx.core.array.item", "python/_autosummary/mlx.core.array.log", "python/_autosummary/mlx.core.array.log1p", "python/_autosummary/mlx.core.array.logsumexp", "python/_autosummary/mlx.core.array.max", "python/_autosummary/mlx.core.array.mean", "python/_autosummary/mlx.core.array.min", "python/_autosummary/mlx.core.array.ndim", "python/_autosummary/mlx.core.array.prod", "python/_autosummary/mlx.core.array.reciprocal", "python/_autosummary/mlx.core.array.reshape", "python/_autosummary/mlx.core.array.round", "python/_autosummary/mlx.core.array.rsqrt", "python/_autosummary/mlx.core.array.shape", "python/_autosummary/mlx.core.array.sin", "python/_autosummary/mlx.core.array.size", "python/_autosummary/mlx.core.array.split", "python/_autosummary/mlx.core.array.sqrt", "python/_autosummary/mlx.core.array.square", "python/_autosummary/mlx.core.array.sum", "python/_autosummary/mlx.core.array.tolist", "python/_autosummary/mlx.core.array.transpose", "python/_autosummary/mlx.core.array.var", "python/_autosummary/mlx.core.array_equal", "python/_autosummary/mlx.core.broadcast_to", "python/_autosummary/mlx.core.ceil", "python/_autosummary/mlx.core.clip", "python/_autosummary/mlx.core.concatenate", "python/_autosummary/mlx.core.conv1d", "python/_autosummary/mlx.core.conv2d", "python/_autosummary/mlx.core.convolve", "python/_autosummary/mlx.core.cos", "python/_autosummary/mlx.core.cosh", "python/_autosummary/mlx.core.default_device", "python/_autosummary/mlx.core.default_stream", "python/_autosummary/mlx.core.dequantize", "python/_autosummary/mlx.core.diag", "python/_autosummary/mlx.core.diagonal", "python/_autosummary/mlx.core.divide", "python/_autosummary/mlx.core.divmod", "python/_autosummary/mlx.core.equal", "python/_autosummary/mlx.core.erf", "python/_autosummary/mlx.core.erfinv", "python/_autosummary/mlx.core.eval", "python/_autosummary/mlx.core.exp", "python/_autosummary/mlx.core.expand_dims", "python/_autosummary/mlx.core.eye", "python/_autosummary/mlx.core.fft.fft", "python/_autosummary/mlx.core.fft.fft2", "python/_autosummary/mlx.core.fft.fftn", "python/_autosummary/mlx.core.fft.ifft", "python/_autosummary/mlx.core.fft.ifft2", 
"python/_autosummary/mlx.core.fft.ifftn", "python/_autosummary/mlx.core.fft.irfft", "python/_autosummary/mlx.core.fft.irfft2", "python/_autosummary/mlx.core.fft.irfftn", "python/_autosummary/mlx.core.fft.rfft", "python/_autosummary/mlx.core.fft.rfft2", "python/_autosummary/mlx.core.fft.rfftn", "python/_autosummary/mlx.core.flatten", "python/_autosummary/mlx.core.floor", "python/_autosummary/mlx.core.floor_divide", "python/_autosummary/mlx.core.full", "python/_autosummary/mlx.core.grad", "python/_autosummary/mlx.core.greater", "python/_autosummary/mlx.core.greater_equal", "python/_autosummary/mlx.core.identity", "python/_autosummary/mlx.core.inner", "python/_autosummary/mlx.core.isinf", "python/_autosummary/mlx.core.isnan", "python/_autosummary/mlx.core.isneginf", "python/_autosummary/mlx.core.isposinf", "python/_autosummary/mlx.core.jvp", "python/_autosummary/mlx.core.less", "python/_autosummary/mlx.core.less_equal", "python/_autosummary/mlx.core.linalg.norm", "python/_autosummary/mlx.core.linalg.qr", "python/_autosummary/mlx.core.linspace", "python/_autosummary/mlx.core.load", "python/_autosummary/mlx.core.log", "python/_autosummary/mlx.core.log10", "python/_autosummary/mlx.core.log1p", "python/_autosummary/mlx.core.log2", "python/_autosummary/mlx.core.logaddexp", "python/_autosummary/mlx.core.logical_and", "python/_autosummary/mlx.core.logical_not", "python/_autosummary/mlx.core.logical_or", "python/_autosummary/mlx.core.logsumexp", "python/_autosummary/mlx.core.matmul", "python/_autosummary/mlx.core.max", "python/_autosummary/mlx.core.maximum", "python/_autosummary/mlx.core.mean", "python/_autosummary/mlx.core.min", "python/_autosummary/mlx.core.minimum", "python/_autosummary/mlx.core.moveaxis", "python/_autosummary/mlx.core.multiply", "python/_autosummary/mlx.core.negative", "python/_autosummary/mlx.core.new_stream", "python/_autosummary/mlx.core.ones", "python/_autosummary/mlx.core.ones_like", "python/_autosummary/mlx.core.outer", "python/_autosummary/mlx.core.pad", "python/_autosummary/mlx.core.partition", "python/_autosummary/mlx.core.prod", "python/_autosummary/mlx.core.quantize", "python/_autosummary/mlx.core.quantized_matmul", "python/_autosummary/mlx.core.random.bernoulli", "python/_autosummary/mlx.core.random.categorical", "python/_autosummary/mlx.core.random.gumbel", "python/_autosummary/mlx.core.random.key", "python/_autosummary/mlx.core.random.normal", "python/_autosummary/mlx.core.random.randint", "python/_autosummary/mlx.core.random.seed", "python/_autosummary/mlx.core.random.split", "python/_autosummary/mlx.core.random.truncated_normal", "python/_autosummary/mlx.core.random.uniform", "python/_autosummary/mlx.core.reciprocal", "python/_autosummary/mlx.core.repeat", "python/_autosummary/mlx.core.reshape", "python/_autosummary/mlx.core.round", "python/_autosummary/mlx.core.rsqrt", "python/_autosummary/mlx.core.save", "python/_autosummary/mlx.core.save_gguf", "python/_autosummary/mlx.core.save_safetensors", "python/_autosummary/mlx.core.savez", "python/_autosummary/mlx.core.savez_compressed", "python/_autosummary/mlx.core.set_default_device", "python/_autosummary/mlx.core.set_default_stream", "python/_autosummary/mlx.core.sigmoid", "python/_autosummary/mlx.core.sign", "python/_autosummary/mlx.core.sin", "python/_autosummary/mlx.core.sinh", "python/_autosummary/mlx.core.softmax", "python/_autosummary/mlx.core.sort", "python/_autosummary/mlx.core.split", "python/_autosummary/mlx.core.sqrt", "python/_autosummary/mlx.core.square", "python/_autosummary/mlx.core.squeeze", 
"python/_autosummary/mlx.core.stack", "python/_autosummary/mlx.core.stop_gradient", "python/_autosummary/mlx.core.subtract", "python/_autosummary/mlx.core.sum", "python/_autosummary/mlx.core.swapaxes", "python/_autosummary/mlx.core.take", "python/_autosummary/mlx.core.take_along_axis", "python/_autosummary/mlx.core.tan", "python/_autosummary/mlx.core.tanh", "python/_autosummary/mlx.core.tensordot", "python/_autosummary/mlx.core.transpose", "python/_autosummary/mlx.core.tri", "python/_autosummary/mlx.core.tril", "python/_autosummary/mlx.core.triu", "python/_autosummary/mlx.core.value_and_grad", "python/_autosummary/mlx.core.var", "python/_autosummary/mlx.core.vjp", "python/_autosummary/mlx.core.vmap", "python/_autosummary/mlx.core.where", "python/_autosummary/mlx.core.zeros", "python/_autosummary/mlx.core.zeros_like", "python/_autosummary/mlx.nn.value_and_grad", "python/_autosummary/mlx.optimizers.AdaDelta", "python/_autosummary/mlx.optimizers.Adafactor", "python/_autosummary/mlx.optimizers.Adagrad", "python/_autosummary/mlx.optimizers.Adam", "python/_autosummary/mlx.optimizers.AdamW", "python/_autosummary/mlx.optimizers.Adamax", "python/_autosummary/mlx.optimizers.Lion", "python/_autosummary/mlx.optimizers.Optimizer", "python/_autosummary/mlx.optimizers.OptimizerState", "python/_autosummary/mlx.optimizers.RMSprop", "python/_autosummary/mlx.optimizers.SGD", "python/_autosummary/mlx.utils.tree_flatten", "python/_autosummary/mlx.utils.tree_map", "python/_autosummary/mlx.utils.tree_unflatten", "python/array", "python/data_types", "python/devices_and_streams", "python/fft", "python/linalg", "python/nn", "python/nn/_autosummary/mlx.nn.ALiBi", "python/nn/_autosummary/mlx.nn.BatchNorm", "python/nn/_autosummary/mlx.nn.Conv1d", "python/nn/_autosummary/mlx.nn.Conv2d", "python/nn/_autosummary/mlx.nn.Dropout", "python/nn/_autosummary/mlx.nn.Dropout2d", "python/nn/_autosummary/mlx.nn.Dropout3d", "python/nn/_autosummary/mlx.nn.Embedding", "python/nn/_autosummary/mlx.nn.GELU", "python/nn/_autosummary/mlx.nn.GroupNorm", "python/nn/_autosummary/mlx.nn.InstanceNorm", "python/nn/_autosummary/mlx.nn.LayerNorm", "python/nn/_autosummary/mlx.nn.Linear", "python/nn/_autosummary/mlx.nn.Mish", "python/nn/_autosummary/mlx.nn.Module.apply", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules", "python/nn/_autosummary/mlx.nn.Module.children", "python/nn/_autosummary/mlx.nn.Module.eval", "python/nn/_autosummary/mlx.nn.Module.filter_and_map", "python/nn/_autosummary/mlx.nn.Module.freeze", "python/nn/_autosummary/mlx.nn.Module.leaf_modules", "python/nn/_autosummary/mlx.nn.Module.load_weights", "python/nn/_autosummary/mlx.nn.Module.modules", "python/nn/_autosummary/mlx.nn.Module.named_modules", "python/nn/_autosummary/mlx.nn.Module.parameters", "python/nn/_autosummary/mlx.nn.Module.save_weights", "python/nn/_autosummary/mlx.nn.Module.train", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters", "python/nn/_autosummary/mlx.nn.Module.training", "python/nn/_autosummary/mlx.nn.Module.unfreeze", "python/nn/_autosummary/mlx.nn.Module.update", "python/nn/_autosummary/mlx.nn.Module.update_modules", "python/nn/_autosummary/mlx.nn.MultiHeadAttention", "python/nn/_autosummary/mlx.nn.PReLU", "python/nn/_autosummary/mlx.nn.QuantizedLinear", "python/nn/_autosummary/mlx.nn.RMSNorm", "python/nn/_autosummary/mlx.nn.ReLU", "python/nn/_autosummary/mlx.nn.RoPE", "python/nn/_autosummary/mlx.nn.SELU", "python/nn/_autosummary/mlx.nn.Sequential", "python/nn/_autosummary/mlx.nn.SiLU", 
"python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding", "python/nn/_autosummary/mlx.nn.Softshrink", "python/nn/_autosummary/mlx.nn.Step", "python/nn/_autosummary/mlx.nn.Transformer", "python/nn/_autosummary/mlx.nn.init.constant", "python/nn/_autosummary/mlx.nn.init.glorot_normal", "python/nn/_autosummary/mlx.nn.init.glorot_uniform", "python/nn/_autosummary/mlx.nn.init.he_normal", "python/nn/_autosummary/mlx.nn.init.he_uniform", "python/nn/_autosummary/mlx.nn.init.identity", "python/nn/_autosummary/mlx.nn.init.normal", "python/nn/_autosummary/mlx.nn.init.uniform", "python/nn/_autosummary_functions/mlx.nn.gelu", "python/nn/_autosummary_functions/mlx.nn.gelu_approx", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss", "python/nn/_autosummary_functions/mlx.nn.mish", "python/nn/_autosummary_functions/mlx.nn.prelu", "python/nn/_autosummary_functions/mlx.nn.relu", "python/nn/_autosummary_functions/mlx.nn.selu", "python/nn/_autosummary_functions/mlx.nn.silu", "python/nn/_autosummary_functions/mlx.nn.softshrink", "python/nn/_autosummary_functions/mlx.nn.step", "python/nn/functions", "python/nn/init", "python/nn/layers", "python/nn/losses", "python/nn/module", "python/ops", "python/optimizers", "python/random", "python/transforms", "python/tree_utils", "usage/function_transforms", "usage/indexing", "usage/lazy_evaluation", "usage/numpy", "usage/quick_start", "usage/saving_and_loading", "usage/unified_memory", "usage/using_streams"], "filenames": ["cpp/ops.rst", "dev/extensions.rst", "examples/linear_regression.rst", "examples/llama-inference.rst", "examples/mlp.rst", "index.rst", "install.rst", "python/_autosummary/mlx.core.Device.rst", "python/_autosummary/mlx.core.Dtype.rst", "python/_autosummary/mlx.core.Stream.rst", "python/_autosummary/mlx.core.abs.rst", "python/_autosummary/mlx.core.add.rst", "python/_autosummary/mlx.core.all.rst", "python/_autosummary/mlx.core.allclose.rst", "python/_autosummary/mlx.core.any.rst", "python/_autosummary/mlx.core.arange.rst", "python/_autosummary/mlx.core.arccos.rst", "python/_autosummary/mlx.core.arccosh.rst", "python/_autosummary/mlx.core.arcsin.rst", "python/_autosummary/mlx.core.arcsinh.rst", "python/_autosummary/mlx.core.arctan.rst", "python/_autosummary/mlx.core.arctanh.rst", "python/_autosummary/mlx.core.argmax.rst", "python/_autosummary/mlx.core.argmin.rst", "python/_autosummary/mlx.core.argpartition.rst", "python/_autosummary/mlx.core.argsort.rst", "python/_autosummary/mlx.core.array.rst", "python/_autosummary/mlx.core.array.T.rst", "python/_autosummary/mlx.core.array.abs.rst", "python/_autosummary/mlx.core.array.all.rst", "python/_autosummary/mlx.core.array.any.rst", "python/_autosummary/mlx.core.array.argmax.rst", 
"python/_autosummary/mlx.core.array.argmin.rst", "python/_autosummary/mlx.core.array.astype.rst", "python/_autosummary/mlx.core.array.cos.rst", "python/_autosummary/mlx.core.array.dtype.rst", "python/_autosummary/mlx.core.array.exp.rst", "python/_autosummary/mlx.core.array.item.rst", "python/_autosummary/mlx.core.array.log.rst", "python/_autosummary/mlx.core.array.log1p.rst", "python/_autosummary/mlx.core.array.logsumexp.rst", "python/_autosummary/mlx.core.array.max.rst", "python/_autosummary/mlx.core.array.mean.rst", "python/_autosummary/mlx.core.array.min.rst", "python/_autosummary/mlx.core.array.ndim.rst", "python/_autosummary/mlx.core.array.prod.rst", "python/_autosummary/mlx.core.array.reciprocal.rst", "python/_autosummary/mlx.core.array.reshape.rst", "python/_autosummary/mlx.core.array.round.rst", "python/_autosummary/mlx.core.array.rsqrt.rst", "python/_autosummary/mlx.core.array.shape.rst", "python/_autosummary/mlx.core.array.sin.rst", "python/_autosummary/mlx.core.array.size.rst", "python/_autosummary/mlx.core.array.split.rst", "python/_autosummary/mlx.core.array.sqrt.rst", "python/_autosummary/mlx.core.array.square.rst", "python/_autosummary/mlx.core.array.sum.rst", "python/_autosummary/mlx.core.array.tolist.rst", "python/_autosummary/mlx.core.array.transpose.rst", "python/_autosummary/mlx.core.array.var.rst", "python/_autosummary/mlx.core.array_equal.rst", "python/_autosummary/mlx.core.broadcast_to.rst", "python/_autosummary/mlx.core.ceil.rst", "python/_autosummary/mlx.core.clip.rst", "python/_autosummary/mlx.core.concatenate.rst", "python/_autosummary/mlx.core.conv1d.rst", "python/_autosummary/mlx.core.conv2d.rst", "python/_autosummary/mlx.core.convolve.rst", "python/_autosummary/mlx.core.cos.rst", "python/_autosummary/mlx.core.cosh.rst", "python/_autosummary/mlx.core.default_device.rst", "python/_autosummary/mlx.core.default_stream.rst", "python/_autosummary/mlx.core.dequantize.rst", "python/_autosummary/mlx.core.diag.rst", "python/_autosummary/mlx.core.diagonal.rst", "python/_autosummary/mlx.core.divide.rst", "python/_autosummary/mlx.core.divmod.rst", "python/_autosummary/mlx.core.equal.rst", "python/_autosummary/mlx.core.erf.rst", "python/_autosummary/mlx.core.erfinv.rst", "python/_autosummary/mlx.core.eval.rst", "python/_autosummary/mlx.core.exp.rst", "python/_autosummary/mlx.core.expand_dims.rst", "python/_autosummary/mlx.core.eye.rst", "python/_autosummary/mlx.core.fft.fft.rst", "python/_autosummary/mlx.core.fft.fft2.rst", "python/_autosummary/mlx.core.fft.fftn.rst", "python/_autosummary/mlx.core.fft.ifft.rst", "python/_autosummary/mlx.core.fft.ifft2.rst", "python/_autosummary/mlx.core.fft.ifftn.rst", "python/_autosummary/mlx.core.fft.irfft.rst", "python/_autosummary/mlx.core.fft.irfft2.rst", "python/_autosummary/mlx.core.fft.irfftn.rst", "python/_autosummary/mlx.core.fft.rfft.rst", "python/_autosummary/mlx.core.fft.rfft2.rst", "python/_autosummary/mlx.core.fft.rfftn.rst", "python/_autosummary/mlx.core.flatten.rst", "python/_autosummary/mlx.core.floor.rst", "python/_autosummary/mlx.core.floor_divide.rst", "python/_autosummary/mlx.core.full.rst", "python/_autosummary/mlx.core.grad.rst", "python/_autosummary/mlx.core.greater.rst", "python/_autosummary/mlx.core.greater_equal.rst", "python/_autosummary/mlx.core.identity.rst", "python/_autosummary/mlx.core.inner.rst", "python/_autosummary/mlx.core.isinf.rst", "python/_autosummary/mlx.core.isnan.rst", "python/_autosummary/mlx.core.isneginf.rst", "python/_autosummary/mlx.core.isposinf.rst", "python/_autosummary/mlx.core.jvp.rst", 
"python/_autosummary/mlx.core.less.rst", "python/_autosummary/mlx.core.less_equal.rst", "python/_autosummary/mlx.core.linalg.norm.rst", "python/_autosummary/mlx.core.linalg.qr.rst", "python/_autosummary/mlx.core.linspace.rst", "python/_autosummary/mlx.core.load.rst", "python/_autosummary/mlx.core.log.rst", "python/_autosummary/mlx.core.log10.rst", "python/_autosummary/mlx.core.log1p.rst", "python/_autosummary/mlx.core.log2.rst", "python/_autosummary/mlx.core.logaddexp.rst", "python/_autosummary/mlx.core.logical_and.rst", "python/_autosummary/mlx.core.logical_not.rst", "python/_autosummary/mlx.core.logical_or.rst", "python/_autosummary/mlx.core.logsumexp.rst", "python/_autosummary/mlx.core.matmul.rst", "python/_autosummary/mlx.core.max.rst", "python/_autosummary/mlx.core.maximum.rst", "python/_autosummary/mlx.core.mean.rst", "python/_autosummary/mlx.core.min.rst", "python/_autosummary/mlx.core.minimum.rst", "python/_autosummary/mlx.core.moveaxis.rst", "python/_autosummary/mlx.core.multiply.rst", "python/_autosummary/mlx.core.negative.rst", "python/_autosummary/mlx.core.new_stream.rst", "python/_autosummary/mlx.core.ones.rst", "python/_autosummary/mlx.core.ones_like.rst", "python/_autosummary/mlx.core.outer.rst", "python/_autosummary/mlx.core.pad.rst", "python/_autosummary/mlx.core.partition.rst", "python/_autosummary/mlx.core.prod.rst", "python/_autosummary/mlx.core.quantize.rst", "python/_autosummary/mlx.core.quantized_matmul.rst", "python/_autosummary/mlx.core.random.bernoulli.rst", "python/_autosummary/mlx.core.random.categorical.rst", "python/_autosummary/mlx.core.random.gumbel.rst", "python/_autosummary/mlx.core.random.key.rst", "python/_autosummary/mlx.core.random.normal.rst", "python/_autosummary/mlx.core.random.randint.rst", "python/_autosummary/mlx.core.random.seed.rst", "python/_autosummary/mlx.core.random.split.rst", "python/_autosummary/mlx.core.random.truncated_normal.rst", "python/_autosummary/mlx.core.random.uniform.rst", "python/_autosummary/mlx.core.reciprocal.rst", "python/_autosummary/mlx.core.repeat.rst", "python/_autosummary/mlx.core.reshape.rst", "python/_autosummary/mlx.core.round.rst", "python/_autosummary/mlx.core.rsqrt.rst", "python/_autosummary/mlx.core.save.rst", "python/_autosummary/mlx.core.save_gguf.rst", "python/_autosummary/mlx.core.save_safetensors.rst", "python/_autosummary/mlx.core.savez.rst", "python/_autosummary/mlx.core.savez_compressed.rst", "python/_autosummary/mlx.core.set_default_device.rst", "python/_autosummary/mlx.core.set_default_stream.rst", "python/_autosummary/mlx.core.sigmoid.rst", "python/_autosummary/mlx.core.sign.rst", "python/_autosummary/mlx.core.sin.rst", "python/_autosummary/mlx.core.sinh.rst", "python/_autosummary/mlx.core.softmax.rst", "python/_autosummary/mlx.core.sort.rst", "python/_autosummary/mlx.core.split.rst", "python/_autosummary/mlx.core.sqrt.rst", "python/_autosummary/mlx.core.square.rst", "python/_autosummary/mlx.core.squeeze.rst", "python/_autosummary/mlx.core.stack.rst", "python/_autosummary/mlx.core.stop_gradient.rst", "python/_autosummary/mlx.core.subtract.rst", "python/_autosummary/mlx.core.sum.rst", "python/_autosummary/mlx.core.swapaxes.rst", "python/_autosummary/mlx.core.take.rst", "python/_autosummary/mlx.core.take_along_axis.rst", "python/_autosummary/mlx.core.tan.rst", "python/_autosummary/mlx.core.tanh.rst", "python/_autosummary/mlx.core.tensordot.rst", "python/_autosummary/mlx.core.transpose.rst", "python/_autosummary/mlx.core.tri.rst", "python/_autosummary/mlx.core.tril.rst", 
"python/_autosummary/mlx.core.triu.rst", "python/_autosummary/mlx.core.value_and_grad.rst", "python/_autosummary/mlx.core.var.rst", "python/_autosummary/mlx.core.vjp.rst", "python/_autosummary/mlx.core.vmap.rst", "python/_autosummary/mlx.core.where.rst", "python/_autosummary/mlx.core.zeros.rst", "python/_autosummary/mlx.core.zeros_like.rst", "python/_autosummary/mlx.nn.value_and_grad.rst", "python/_autosummary/mlx.optimizers.AdaDelta.rst", "python/_autosummary/mlx.optimizers.Adafactor.rst", "python/_autosummary/mlx.optimizers.Adagrad.rst", "python/_autosummary/mlx.optimizers.Adam.rst", "python/_autosummary/mlx.optimizers.AdamW.rst", "python/_autosummary/mlx.optimizers.Adamax.rst", "python/_autosummary/mlx.optimizers.Lion.rst", "python/_autosummary/mlx.optimizers.Optimizer.rst", "python/_autosummary/mlx.optimizers.OptimizerState.rst", "python/_autosummary/mlx.optimizers.RMSprop.rst", "python/_autosummary/mlx.optimizers.SGD.rst", "python/_autosummary/mlx.utils.tree_flatten.rst", "python/_autosummary/mlx.utils.tree_map.rst", "python/_autosummary/mlx.utils.tree_unflatten.rst", "python/array.rst", "python/data_types.rst", "python/devices_and_streams.rst", "python/fft.rst", "python/linalg.rst", "python/nn.rst", "python/nn/_autosummary/mlx.nn.ALiBi.rst", "python/nn/_autosummary/mlx.nn.BatchNorm.rst", "python/nn/_autosummary/mlx.nn.Conv1d.rst", "python/nn/_autosummary/mlx.nn.Conv2d.rst", "python/nn/_autosummary/mlx.nn.Dropout.rst", "python/nn/_autosummary/mlx.nn.Dropout2d.rst", "python/nn/_autosummary/mlx.nn.Dropout3d.rst", "python/nn/_autosummary/mlx.nn.Embedding.rst", "python/nn/_autosummary/mlx.nn.GELU.rst", "python/nn/_autosummary/mlx.nn.GroupNorm.rst", "python/nn/_autosummary/mlx.nn.InstanceNorm.rst", "python/nn/_autosummary/mlx.nn.LayerNorm.rst", "python/nn/_autosummary/mlx.nn.Linear.rst", "python/nn/_autosummary/mlx.nn.Mish.rst", "python/nn/_autosummary/mlx.nn.Module.apply.rst", "python/nn/_autosummary/mlx.nn.Module.apply_to_modules.rst", "python/nn/_autosummary/mlx.nn.Module.children.rst", "python/nn/_autosummary/mlx.nn.Module.eval.rst", "python/nn/_autosummary/mlx.nn.Module.filter_and_map.rst", "python/nn/_autosummary/mlx.nn.Module.freeze.rst", "python/nn/_autosummary/mlx.nn.Module.leaf_modules.rst", "python/nn/_autosummary/mlx.nn.Module.load_weights.rst", "python/nn/_autosummary/mlx.nn.Module.modules.rst", "python/nn/_autosummary/mlx.nn.Module.named_modules.rst", "python/nn/_autosummary/mlx.nn.Module.parameters.rst", "python/nn/_autosummary/mlx.nn.Module.save_weights.rst", "python/nn/_autosummary/mlx.nn.Module.train.rst", "python/nn/_autosummary/mlx.nn.Module.trainable_parameters.rst", "python/nn/_autosummary/mlx.nn.Module.training.rst", "python/nn/_autosummary/mlx.nn.Module.unfreeze.rst", "python/nn/_autosummary/mlx.nn.Module.update.rst", "python/nn/_autosummary/mlx.nn.Module.update_modules.rst", "python/nn/_autosummary/mlx.nn.MultiHeadAttention.rst", "python/nn/_autosummary/mlx.nn.PReLU.rst", "python/nn/_autosummary/mlx.nn.QuantizedLinear.rst", "python/nn/_autosummary/mlx.nn.RMSNorm.rst", "python/nn/_autosummary/mlx.nn.ReLU.rst", "python/nn/_autosummary/mlx.nn.RoPE.rst", "python/nn/_autosummary/mlx.nn.SELU.rst", "python/nn/_autosummary/mlx.nn.Sequential.rst", "python/nn/_autosummary/mlx.nn.SiLU.rst", "python/nn/_autosummary/mlx.nn.SinusoidalPositionalEncoding.rst", "python/nn/_autosummary/mlx.nn.Softshrink.rst", "python/nn/_autosummary/mlx.nn.Step.rst", "python/nn/_autosummary/mlx.nn.Transformer.rst", "python/nn/_autosummary/mlx.nn.init.constant.rst", 
"python/nn/_autosummary/mlx.nn.init.glorot_normal.rst", "python/nn/_autosummary/mlx.nn.init.glorot_uniform.rst", "python/nn/_autosummary/mlx.nn.init.he_normal.rst", "python/nn/_autosummary/mlx.nn.init.he_uniform.rst", "python/nn/_autosummary/mlx.nn.init.identity.rst", "python/nn/_autosummary/mlx.nn.init.normal.rst", "python/nn/_autosummary/mlx.nn.init.uniform.rst", "python/nn/_autosummary_functions/mlx.nn.gelu.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_approx.rst", "python/nn/_autosummary_functions/mlx.nn.gelu_fast_approx.rst", "python/nn/_autosummary_functions/mlx.nn.losses.binary_cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cosine_similarity_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.cross_entropy.rst", "python/nn/_autosummary_functions/mlx.nn.losses.gaussian_nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.hinge_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.huber_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.rst", "python/nn/_autosummary_functions/mlx.nn.losses.triplet_loss.rst", "python/nn/_autosummary_functions/mlx.nn.mish.rst", "python/nn/_autosummary_functions/mlx.nn.prelu.rst", "python/nn/_autosummary_functions/mlx.nn.relu.rst", "python/nn/_autosummary_functions/mlx.nn.selu.rst", "python/nn/_autosummary_functions/mlx.nn.silu.rst", "python/nn/_autosummary_functions/mlx.nn.softshrink.rst", "python/nn/_autosummary_functions/mlx.nn.step.rst", "python/nn/functions.rst", "python/nn/init.rst", "python/nn/layers.rst", "python/nn/losses.rst", "python/nn/module.rst", "python/ops.rst", "python/optimizers.rst", "python/random.rst", "python/transforms.rst", "python/tree_utils.rst", "usage/function_transforms.rst", "usage/indexing.rst", "usage/lazy_evaluation.rst", "usage/numpy.rst", "usage/quick_start.rst", "usage/saving_and_loading.rst", "usage/unified_memory.rst", "usage/using_streams.rst"], "titles": ["Operations", "Developer Documentation", "Linear Regression", "LLM inference", "Multi-Layer Perceptron", "MLX", "Build and Install", "mlx.core.Device", "mlx.core.Dtype", "mlx.core.Stream", "mlx.core.abs", "mlx.core.add", "mlx.core.all", "mlx.core.allclose", "mlx.core.any", "mlx.core.arange", "mlx.core.arccos", "mlx.core.arccosh", "mlx.core.arcsin", "mlx.core.arcsinh", "mlx.core.arctan", "mlx.core.arctanh", "mlx.core.argmax", "mlx.core.argmin", "mlx.core.argpartition", "mlx.core.argsort", "mlx.core.array", "mlx.core.array.T", "mlx.core.array.abs", "mlx.core.array.all", "mlx.core.array.any", "mlx.core.array.argmax", "mlx.core.array.argmin", "mlx.core.array.astype", "mlx.core.array.cos", "mlx.core.array.dtype", "mlx.core.array.exp", "mlx.core.array.item", "mlx.core.array.log", "mlx.core.array.log1p", "mlx.core.array.logsumexp", "mlx.core.array.max", "mlx.core.array.mean", "mlx.core.array.min", "mlx.core.array.ndim", "mlx.core.array.prod", "mlx.core.array.reciprocal", "mlx.core.array.reshape", "mlx.core.array.round", "mlx.core.array.rsqrt", "mlx.core.array.shape", "mlx.core.array.sin", "mlx.core.array.size", "mlx.core.array.split", "mlx.core.array.sqrt", "mlx.core.array.square", "mlx.core.array.sum", "mlx.core.array.tolist", "mlx.core.array.transpose", "mlx.core.array.var", 
"mlx.core.array_equal", "mlx.core.broadcast_to", "mlx.core.ceil", "mlx.core.clip", "mlx.core.concatenate", "mlx.core.conv1d", "mlx.core.conv2d", "mlx.core.convolve", "mlx.core.cos", "mlx.core.cosh", "mlx.core.default_device", "mlx.core.default_stream", "mlx.core.dequantize", "mlx.core.diag", "mlx.core.diagonal", "mlx.core.divide", "mlx.core.divmod", "mlx.core.equal", "mlx.core.erf", "mlx.core.erfinv", "mlx.core.eval", "mlx.core.exp", "mlx.core.expand_dims", "mlx.core.eye", "mlx.core.fft.fft", "mlx.core.fft.fft2", "mlx.core.fft.fftn", "mlx.core.fft.ifft", "mlx.core.fft.ifft2", "mlx.core.fft.ifftn", "mlx.core.fft.irfft", "mlx.core.fft.irfft2", "mlx.core.fft.irfftn", "mlx.core.fft.rfft", "mlx.core.fft.rfft2", "mlx.core.fft.rfftn", "mlx.core.flatten", "mlx.core.floor", "mlx.core.floor_divide", "mlx.core.full", "mlx.core.grad", "mlx.core.greater", "mlx.core.greater_equal", "mlx.core.identity", "mlx.core.inner", "mlx.core.isinf", "mlx.core.isnan", "mlx.core.isneginf", "mlx.core.isposinf", "mlx.core.jvp", "mlx.core.less", "mlx.core.less_equal", "mlx.core.linalg.norm", "mlx.core.linalg.qr", "mlx.core.linspace", "mlx.core.load", "mlx.core.log", "mlx.core.log10", "mlx.core.log1p", "mlx.core.log2", "mlx.core.logaddexp", "mlx.core.logical_and", "mlx.core.logical_not", "mlx.core.logical_or", "mlx.core.logsumexp", "mlx.core.matmul", "mlx.core.max", "mlx.core.maximum", "mlx.core.mean", "mlx.core.min", "mlx.core.minimum", "mlx.core.moveaxis", "mlx.core.multiply", "mlx.core.negative", "mlx.core.new_stream", "mlx.core.ones", "mlx.core.ones_like", "mlx.core.outer", "mlx.core.pad", "mlx.core.partition", "mlx.core.prod", "mlx.core.quantize", "mlx.core.quantized_matmul", "mlx.core.random.bernoulli", "mlx.core.random.categorical", "mlx.core.random.gumbel", "mlx.core.random.key", "mlx.core.random.normal", "mlx.core.random.randint", "mlx.core.random.seed", "mlx.core.random.split", "mlx.core.random.truncated_normal", "mlx.core.random.uniform", "mlx.core.reciprocal", "mlx.core.repeat", "mlx.core.reshape", "mlx.core.round", "mlx.core.rsqrt", "mlx.core.save", "mlx.core.save_gguf", "mlx.core.save_safetensors", "mlx.core.savez", "mlx.core.savez_compressed", "mlx.core.set_default_device", "mlx.core.set_default_stream", "mlx.core.sigmoid", "mlx.core.sign", "mlx.core.sin", "mlx.core.sinh", "mlx.core.softmax", "mlx.core.sort", "mlx.core.split", "mlx.core.sqrt", "mlx.core.square", "mlx.core.squeeze", "mlx.core.stack", "mlx.core.stop_gradient", "mlx.core.subtract", "mlx.core.sum", "mlx.core.swapaxes", "mlx.core.take", "mlx.core.take_along_axis", "mlx.core.tan", "mlx.core.tanh", "mlx.core.tensordot", "mlx.core.transpose", "mlx.core.tri", "mlx.core.tril", "mlx.core.triu", "mlx.core.value_and_grad", "mlx.core.var", "mlx.core.vjp", "mlx.core.vmap", "mlx.core.where", "mlx.core.zeros", "mlx.core.zeros_like", "mlx.nn.value_and_grad", "mlx.optimizers.AdaDelta", "mlx.optimizers.Adafactor", "mlx.optimizers.Adagrad", "mlx.optimizers.Adam", "mlx.optimizers.AdamW", "mlx.optimizers.Adamax", "mlx.optimizers.Lion", "mlx.optimizers.Optimizer", "mlx.optimizers.OptimizerState", "mlx.optimizers.RMSprop", "mlx.optimizers.SGD", "mlx.utils.tree_flatten", "mlx.utils.tree_map", "mlx.utils.tree_unflatten", "Array", "Data Types", "Devices and Streams", "FFT", "Linear Algebra", "Neural Networks", "mlx.nn.ALiBi", "mlx.nn.BatchNorm", "mlx.nn.Conv1d", "mlx.nn.Conv2d", "mlx.nn.Dropout", "mlx.nn.Dropout2d", "mlx.nn.Dropout3d", "mlx.nn.Embedding", "mlx.nn.GELU", "mlx.nn.GroupNorm", "mlx.nn.InstanceNorm", "mlx.nn.LayerNorm", "mlx.nn.Linear", "mlx.nn.Mish", 
"mlx.nn.Module.apply", "mlx.nn.Module.apply_to_modules", "mlx.nn.Module.children", "mlx.nn.Module.eval", "mlx.nn.Module.filter_and_map", "mlx.nn.Module.freeze", "mlx.nn.Module.leaf_modules", "mlx.nn.Module.load_weights", "mlx.nn.Module.modules", "mlx.nn.Module.named_modules", "mlx.nn.Module.parameters", "mlx.nn.Module.save_weights", "mlx.nn.Module.train", "mlx.nn.Module.trainable_parameters", "mlx.nn.Module.training", "mlx.nn.Module.unfreeze", "mlx.nn.Module.update", "mlx.nn.Module.update_modules", "mlx.nn.MultiHeadAttention", "mlx.nn.PReLU", "mlx.nn.QuantizedLinear", "mlx.nn.RMSNorm", "mlx.nn.ReLU", "mlx.nn.RoPE", "mlx.nn.SELU", "mlx.nn.Sequential", "mlx.nn.SiLU", "mlx.nn.SinusoidalPositionalEncoding", "mlx.nn.Softshrink", "mlx.nn.Step", "mlx.nn.Transformer", "mlx.nn.init.constant", "mlx.nn.init.glorot_normal", "mlx.nn.init.glorot_uniform", "mlx.nn.init.he_normal", "mlx.nn.init.he_uniform", "mlx.nn.init.identity", "mlx.nn.init.normal", "mlx.nn.init.uniform", "mlx.nn.gelu", "mlx.nn.gelu_approx", "mlx.nn.gelu_fast_approx", "mlx.nn.losses.binary_cross_entropy", "mlx.nn.losses.cosine_similarity_loss", "mlx.nn.losses.cross_entropy", "mlx.nn.losses.gaussian_nll_loss", "mlx.nn.losses.hinge_loss", "mlx.nn.losses.huber_loss", "mlx.nn.losses.kl_div_loss", "mlx.nn.losses.l1_loss", "mlx.nn.losses.log_cosh_loss", "mlx.nn.losses.mse_loss", "mlx.nn.losses.nll_loss", "mlx.nn.losses.smooth_l1_loss", "mlx.nn.losses.triplet_loss", "mlx.nn.mish", "mlx.nn.prelu", "mlx.nn.relu", "mlx.nn.selu", "mlx.nn.silu", "mlx.nn.softshrink", "mlx.nn.step", "Functions", "Initializers", "Layers", "Loss Functions", "Module", "Operations", "Optimizers", "Random", "Transforms", "Tree Utils", "Function Transforms", "Indexing Arrays", "Lazy Evaluation", "Conversion to NumPy and Other Frameworks", "Quick Start Guide", "Saving and Loading Arrays", "Unified Memory", "Using Streams"], "terms": {"mlx": [1, 2, 3, 4, 6, 216, 294, 297, 299, 300, 302, 303, 304, 305, 306, 307, 308, 309], "provid": [1, 3, 72, 100, 184, 189, 209, 216, 231, 236, 238, 246, 247, 248, 251, 261, 293, 297, 308, 310], "open": [1, 6, 15, 148, 152], "flexibl": [1, 5, 248], "which": [1, 3, 4, 5, 6, 15, 33, 74, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 100, 105, 106, 107, 108, 109, 112, 113, 115, 141, 144, 145, 154, 155, 158, 159, 160, 161, 162, 174, 175, 180, 189, 191, 192, 222, 223, 225, 231, 235, 254, 275, 278, 284, 294, 300, 303, 304, 305, 309, 310], "user": [1, 3, 216], "mai": [1, 112, 222, 303, 304], "add": [1, 3, 82, 120, 138, 141, 219, 220, 303, 309], "special": 1, "without": [1, 3, 5, 176, 249, 293, 302, 305, 306, 309], "much": [1, 3, 305], "hassl": 1, "while": [1, 3, 6, 155, 254, 305, 306], "librari": [1, 6, 216], "suppli": 1, "effici": [1, 3, 5, 222, 254, 305, 307], "can": [1, 3, 5, 6, 11, 15, 47, 58, 74, 75, 76, 77, 80, 101, 102, 110, 111, 112, 120, 127, 130, 132, 143, 144, 148, 151, 152, 159, 177, 189, 216, 224, 235, 246, 256, 275, 294, 297, 299, 300, 302, 303, 304, 305, 306, 307, 308, 309, 310], "compos": [1, 5, 216, 303, 307], "ani": [1, 3, 5, 15, 208, 209, 210, 216, 225, 231, 232, 235, 251, 261, 294, 302, 303, 305, 307, 308, 309], "number": [1, 15, 52, 66, 72, 83, 100, 103, 109, 114, 138, 141, 142, 144, 147, 150, 152, 154, 156, 184, 186, 189, 191, 192, 216, 218, 219, 220, 222, 223, 226, 227, 249, 250, 261, 263, 264, 265, 266, 300, 303, 310], "applic": [1, 6], "aris": [1, 306], "case": [1, 3, 86, 89, 90, 92, 93, 94, 95, 96, 113, 125, 155, 174, 222, 255, 260, 284, 289, 291, 292, 303, 307, 308, 309, 310], "where": [1, 4, 83, 141, 189, 192, 218, 
219, 220, 221, 222, 223, 225, 226, 227, 228, 229, 235, 250, 252, 255, 257, 260, 265, 266, 270, 271, 272, 276, 287, 289, 290, 292, 303, 304], "new": [1, 4, 61, 74, 131, 155, 175, 185, 209, 249, 297, 299, 304, 305, 306], "function": [1, 2, 3, 4, 5, 13, 76, 78, 79, 100, 109, 112, 113, 125, 165, 189, 191, 192, 196, 209, 216, 225, 230, 232, 236, 246, 250, 256, 259, 260, 261, 270, 271, 272, 286, 291, 292, 294, 299, 300, 302, 304, 305, 306, 308], "highli": [1, 6], "optim": [1, 2, 4, 5, 247, 303, 305], "ar": [1, 2, 3, 4, 5, 6, 13, 15, 60, 61, 63, 67, 74, 83, 85, 86, 88, 89, 91, 92, 94, 95, 96, 100, 105, 106, 107, 108, 109, 112, 113, 115, 125, 137, 138, 139, 141, 142, 143, 144, 145, 148, 151, 152, 161, 162, 174, 175, 180, 189, 191, 192, 203, 208, 209, 218, 219, 220, 221, 222, 223, 226, 227, 228, 229, 238, 249, 251, 273, 275, 276, 293, 297, 302, 303, 304, 305, 306, 307, 308, 309], "need": [1, 3, 4, 5, 60, 141, 216, 247, 248, 258, 261, 300, 303, 305, 306, 307, 309], "For": [1, 3, 6, 112, 141, 210, 216, 218, 222, 231, 236, 243, 246, 251, 254, 258, 263, 264, 265, 266, 294, 300, 304, 305, 306, 307, 308, 309], "you": [1, 3, 4, 5, 6, 216, 258, 261, 294, 300, 303, 304, 306, 308, 309], "design": [1, 2, 5, 300, 309], "your": [1, 3, 6, 297, 303, 305], "own": [1, 6, 306], "link": [1, 6], "top": [1, 229], "core": [1, 2, 3, 4, 216, 218, 227, 238, 241, 244, 262, 263, 264, 265, 266, 267, 268, 269, 273, 275, 294, 297, 299, 306, 307], "we": [1, 2, 3, 4, 72, 141, 142, 201, 203, 216, 224, 256, 300, 302, 303, 305, 309], "inner": 1, "work": [1, 3, 6, 303, 304, 305], "go": [1, 3, 303], "over": [1, 3, 4, 12, 14, 22, 23, 24, 25, 65, 66, 86, 89, 92, 95, 104, 112, 114, 124, 126, 128, 129, 139, 140, 157, 169, 170, 178, 184, 190, 218, 219, 220, 226, 228, 252, 275, 303], "simpl": [1, 3, 4, 216, 224, 293, 303, 305], "learn": [1, 2, 4, 5, 197, 198, 199, 200, 201, 202, 203, 206, 207, 218, 226, 227, 228, 250, 252], "step": [1, 3, 4, 15, 198, 216], "involv": [1, 299], "ad": [1, 2, 6, 197, 198, 199, 200, 201, 202, 206, 227, 297, 305, 308], "let": [1, 2, 3, 303, 305, 306], "s": [1, 2, 3, 4, 35, 44, 72, 85, 86, 88, 89, 91, 92, 94, 95, 100, 112, 115, 128, 137, 141, 144, 156, 159, 160, 189, 190, 192, 196, 204, 216, 235, 236, 238, 242, 246, 299, 300, 303, 305, 306, 307, 308, 309], "sai": [1, 3, 294, 305], "would": [1, 3, 304, 305, 306, 309], "like": [1, 3, 5, 136, 195, 223, 281, 303, 305, 306, 307, 309], "an": [1, 3, 4, 6, 8, 12, 14, 26, 61, 65, 66, 80, 83, 96, 99, 103, 112, 115, 126, 129, 131, 135, 136, 138, 140, 141, 142, 154, 155, 156, 171, 174, 179, 180, 181, 184, 186, 192, 194, 195, 197, 204, 205, 208, 209, 216, 221, 226, 228, 229, 231, 249, 250, 251, 261, 262, 263, 264, 265, 266, 267, 268, 269, 271, 287, 294, 300, 302, 303, 304, 305, 306, 307, 308, 309, 310], "take": [1, 3, 4, 100, 109, 127, 130, 136, 142, 181, 189, 191, 192, 195, 249, 300, 303, 304, 308, 309, 310], "two": [1, 11, 13, 60, 74, 75, 77, 85, 88, 94, 101, 102, 110, 111, 113, 120, 125, 127, 130, 132, 137, 179, 251, 274, 303, 304, 309], "arrai": [1, 3, 4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 
165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 216, 218, 231, 238, 241, 244, 250, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 292, 294, 297, 303, 305, 306, 307, 309], "x": [1, 2, 3, 4, 78, 103, 112, 142, 145, 156, 161, 165, 187, 188, 193, 203, 209, 216, 218, 225, 226, 227, 228, 229, 230, 231, 250, 252, 253, 255, 257, 258, 260, 270, 271, 272, 284, 286, 287, 288, 289, 290, 291, 292, 297, 299, 303, 304, 305, 306, 307, 309], "y": [1, 2, 3, 4, 193, 199, 216, 218, 222, 226, 227, 228, 229, 252, 277, 284, 299, 303, 305, 306], "scale": [1, 3, 72, 141, 142, 198, 222, 223, 249, 254, 255, 258, 289], "them": [1, 3, 216, 236, 246, 309], "both": [1, 11, 75, 76, 77, 101, 102, 110, 111, 112, 120, 127, 130, 132, 144, 177, 227, 299, 303, 307, 309], "some": [1, 2, 3, 4, 236, 246, 303, 305], "coeffici": [1, 197, 198, 200, 201, 202, 203], "alpha": [1, 141, 201, 206, 255, 285, 287, 289], "beta": [1, 72, 141, 200, 201, 202, 203, 218, 226, 227, 228, 284], "respect": [1, 2, 4, 100, 141, 189, 209, 216, 218, 225, 226, 227, 228, 297, 303, 307], "togeth": [1, 4, 141, 209], "get": [1, 2, 4, 6, 66, 146, 205, 216, 303, 305, 309], "z": [1, 305], "well": [1, 3, 216, 236, 246, 249, 305], "veri": [1, 3, 249, 305, 309], "easili": 1, "do": [1, 3, 6, 201, 216, 237, 246, 294, 297, 303, 305], "just": [1, 4, 304], "write": [1, 3, 216, 306], "out": [1, 6, 222, 223, 243, 303, 304], "follow": [1, 3, 4, 5, 6, 15, 67, 72, 112, 141, 197, 198, 199, 200, 201, 202, 203, 207, 216, 271, 272, 279, 300, 303, 309], "import": [1, 2, 3, 4, 6, 112, 161, 189, 208, 209, 210, 216, 218, 227, 238, 273, 275, 294, 297, 303, 304, 305, 306, 307], "mx": [1, 2, 3, 4, 96, 112, 113, 115, 161, 189, 216, 218, 227, 231, 238, 242, 253, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 279, 288, 294, 297, 299, 300, 303, 304, 305, 306, 307, 308, 309, 310], "def": [1, 2, 3, 4, 189, 216, 297, 303, 304, 305, 306, 309], "simple_axpbi": 1, "float": [1, 13, 15, 57, 98, 99, 112, 142, 143, 148, 151, 152, 197, 198, 199, 200, 201, 202, 203, 206, 207, 212, 218, 221, 222, 223, 226, 227, 228, 231, 252, 254, 258, 260, 261, 262, 263, 264, 265, 266, 268, 269, 274, 275, 276, 278, 284, 285, 291, 292], "return": [1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 33, 37, 50, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 155, 156, 157, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 208, 209, 210, 216, 233, 235, 237, 239, 240, 241, 244, 251, 262, 263, 264, 265, 266, 267, 268, 269, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 294, 297, 302, 303, 304, 305, 306, 308, 309], "thi": [1, 3, 4, 6, 12, 13, 14, 15, 22, 23, 24, 25, 109, 112, 113, 120, 124, 125, 126, 128, 129, 139, 140, 144, 169, 170, 171, 178, 180, 190, 216, 221, 222, 223, 232, 233, 235, 236, 239, 240, 241, 244, 246, 247, 248, 249, 251, 260, 263, 264, 265, 266, 271, 272, 281, 292, 297, 302, 303, 305, 306, 308], "perform": [1, 3, 5, 
[Elided: the regenerated searchindex.js payload for this docs build, a single machine-generated object that is not meaningful to read line by line. Its "terms" map pairs each stemmed search token with the IDs of the documents that contain it (a bare integer when only one document matches, e.g. "thousand": 305). Tokens that now map to an empty list (e.g. "simplifi", "reus", "consumpt") correspond to pages removed in this build. Its "objects" map pairs each module key ("mlx.core", "mlx.core.array", "mlx.core.fft", ...) with one index entry per documented API; this capture cuts off partway through the "mlx.core.fft" entries.]
2, 1, "", "rfft2"], [95, 2, 1, "", "rfftn"]], "mlx.core.linalg": [[112, 2, 1, "", "norm"], [113, 2, 1, "", "qr"]], "mlx.core.random": [[143, 2, 1, "", "bernoulli"], [144, 2, 1, "", "categorical"], [145, 2, 1, "", "gumbel"], [146, 2, 1, "", "key"], [147, 2, 1, "", "normal"], [148, 2, 1, "", "randint"], [149, 2, 1, "", "seed"], [150, 2, 1, "", "split"], [151, 2, 1, "", "truncated_normal"], [152, 2, 1, "", "uniform"]], "mlx.nn": [[217, 0, 1, "", "ALiBi"], [218, 0, 1, "", "BatchNorm"], [219, 0, 1, "", "Conv1d"], [220, 0, 1, "", "Conv2d"], [221, 0, 1, "", "Dropout"], [222, 0, 1, "", "Dropout2d"], [223, 0, 1, "", "Dropout3d"], [224, 0, 1, "", "Embedding"], [225, 0, 1, "", "GELU"], [226, 0, 1, "", "GroupNorm"], [227, 0, 1, "", "InstanceNorm"], [228, 0, 1, "", "LayerNorm"], [229, 0, 1, "", "Linear"], [230, 0, 1, "", "Mish"], [297, 0, 1, "", "Module"], [249, 0, 1, "", "MultiHeadAttention"], [250, 0, 1, "", "PReLU"], [251, 0, 1, "", "QuantizedLinear"], [252, 0, 1, "", "RMSNorm"], [253, 0, 1, "", "ReLU"], [254, 0, 1, "", "RoPE"], [255, 0, 1, "", "SELU"], [256, 0, 1, "", "Sequential"], [257, 0, 1, "", "SiLU"], [258, 0, 1, "", "SinusoidalPositionalEncoding"], [259, 0, 1, "", "Softshrink"], [260, 0, 1, "", "Step"], [261, 0, 1, "", "Transformer"], [270, 0, 1, "", "gelu"], [271, 0, 1, "", "gelu_approx"], [272, 0, 1, "", "gelu_fast_approx"], [286, 0, 1, "", "mish"], [287, 0, 1, "", "prelu"], [288, 0, 1, "", "relu"], [289, 0, 1, "", "selu"], [290, 0, 1, "", "silu"], [291, 0, 1, "", "softshrink"], [292, 0, 1, "", "step"], [196, 2, 1, "", "value_and_grad"]], "mlx.nn.Module": [[231, 1, 1, "", "apply"], [232, 1, 1, "", "apply_to_modules"], [233, 1, 1, "", "children"], [234, 1, 1, "", "eval"], [235, 1, 1, "", "filter_and_map"], [236, 1, 1, "", "freeze"], [237, 1, 1, "", "leaf_modules"], [238, 1, 1, "", "load_weights"], [239, 1, 1, "", "modules"], [240, 1, 1, "", "named_modules"], [241, 1, 1, "", "parameters"], [242, 1, 1, "", "save_weights"], [243, 1, 1, "", "train"], [244, 1, 1, "", "trainable_parameters"], [245, 3, 1, "", "training"], [246, 1, 1, "", "unfreeze"], [247, 1, 1, "", "update"], [248, 1, 1, "", "update_modules"]], "mlx.nn.RoPE": [[254, 4, 1, "", "_cos_sin_theta_key"], [254, 4, 1, "", "_cos_sin_theta_value"]], "mlx.nn.init": [[262, 2, 1, "", "constant"], [263, 2, 1, "", "glorot_normal"], [264, 2, 1, "", "glorot_uniform"], [265, 2, 1, "", "he_normal"], [266, 2, 1, "", "he_uniform"], [267, 2, 1, "", "identity"], [268, 2, 1, "", "normal"], [269, 2, 1, "", "uniform"]], "mlx.nn.losses": [[273, 0, 1, "", "binary_cross_entropy"], [274, 0, 1, "", "cosine_similarity_loss"], [275, 0, 1, "", "cross_entropy"], [276, 0, 1, "", "gaussian_nll_loss"], [277, 0, 1, "", "hinge_loss"], [278, 0, 1, "", "huber_loss"], [279, 0, 1, "", "kl_div_loss"], [280, 0, 1, "", "l1_loss"], [281, 0, 1, "", "log_cosh_loss"], [282, 0, 1, "", "mse_loss"], [283, 0, 1, "", "nll_loss"], [284, 0, 1, "", "smooth_l1_loss"], [285, 0, 1, "", "triplet_loss"]], "mlx.optimizers": [[197, 0, 1, "", "AdaDelta"], [198, 0, 1, "", "Adafactor"], [199, 0, 1, "", "Adagrad"], [200, 0, 1, "", "Adam"], [201, 0, 1, "", "AdamW"], [202, 0, 1, "", "Adamax"], [203, 0, 1, "", "Lion"], [204, 0, 1, "", "Optimizer"], [205, 0, 1, "", "OptimizerState"], [206, 0, 1, "", "RMSprop"], [207, 0, 1, "", "SGD"]], "mlx.optimizers.Optimizer": [[204, 4, 1, "", "state"]], "mlx.utils": [[208, 2, 1, "", "tree_flatten"], [209, 2, 1, "", "tree_map"], [210, 2, 1, "", "tree_unflatten"]]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function", "3": "py:property", "4": 
"py:attribute"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"], "3": ["py", "property", "Python property"], "4": ["py", "attribute", "Python attribute"]}, "titleterms": {"oper": [0, 1, 298], "develop": 1, "document": 1, "introduc": 1, "exampl": [1, 5, 309], "primit": 1, "us": [1, 305, 310], "implement": [1, 3], "cpu": 1, "backend": 1, "gpu": 1, "transform": [1, 261, 301, 303, 305, 307], "build": [1, 6], "bind": 1, "python": [1, 5, 6], "cmake": 1, "setuptool": 1, "usag": [1, 5], "result": 1, "script": [1, 3], "download": [1, 3], "code": [1, 3], "linear": [2, 215, 229], "regress": 2, "llm": 3, "infer": 3, "model": 3, "attent": 3, "layer": [3, 4, 295], "encod": 3, "full": [3, 99], "gener": 3, "put": 3, "all": [3, 12, 29], "togeth": 3, "convert": 3, "weight": 3, "load": [3, 115, 308], "benchmark": 3, "multi": 4, "perceptron": 4, "mlx": [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292], "instal": [5, 6], "api": [5, 6], "refer": 5, "c": [5, 6], "further": 5, "read": 5, "from": [6, 304], "pypi": [], "troubleshoot": 6, "sourc": 6, "requir": 6, "option": 6, "metal": 6, "found": 6, "x86": 6, "shell": 6, "core": [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195], "devic": [7, 213], "dtype": [8, 35], "stream": [9, 213, 310], "ab": [10, 28], "add": 11, "allclos": 13, "ani": [14, 30], "arang": 15, "arcco": 16, "arccosh": 17, "arcsin": 18, "arcsinh": 19, "arctan": 20, "arctanh": 21, "argmax": [22, 
31], "argmin": [23, 32], "argpartit": 24, "argsort": 25, "arrai": [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 211, 304, 308], "t": 27, "astyp": 33, "co": [34, 68], "exp": [36, 81], "item": 37, "log": [38, 116], "log1p": [39, 118], "logsumexp": [40, 124], "max": [41, 126], "mean": [42, 128], "min": [43, 129], "ndim": 44, "prod": [45, 140], "reciproc": [46, 153], "reshap": [47, 155], "round": [48, 156], "rsqrt": [49, 157], "shape": 50, "sin": [51, 167], "size": 52, "split": [53, 150, 171], "sqrt": [54, 172], "squar": [55, 173], "sum": [56, 178], "tolist": 57, "transpos": [58, 185], "var": [59, 190], "array_equ": 60, "broadcast_to": 61, "ceil": 62, "clip": 63, "concaten": 64, "conv1d": [65, 219], "conv2d": [66, 220], "convolv": 67, "cosh": 69, "default_devic": 70, "default_stream": 71, "dequant": 72, "divid": 75, "divmod": 76, "equal": 77, "erf": 78, "erfinv": 79, "eval": [80, 234], "expand_dim": 82, "ey": 83, "fft": [84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 214], "fft2": 85, "fftn": 86, "ifft": 87, "ifft2": 88, "ifftn": 89, "irfft": 90, "irfft2": 91, "irfftn": 92, "rfft": 93, "rfft2": 94, "rfftn": 95, "flatten": 96, "floor": 97, "floor_divid": 98, "grad": [100, 216], "greater": 101, "greater_equ": 102, "ident": [103, 267], "inner": 104, "jvp": 109, "less": 110, "less_equ": 111, "linalg": [112, 113], "norm": 112, "linspac": 114, "log10": 117, "log2": 119, "logaddexp": 120, "logical_and": 121, "logical_not": 122, "logical_or": 123, "matmul": 125, "maximum": 127, "minimum": 130, "moveaxi": 131, "multipli": 132, "neg": 133, "new_stream": 134, "ones": 135, "ones_lik": 136, "outer": 137, "pad": 138, "partit": 139, "quantiz": 141, "quantized_matmul": 142, "random": [143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 300], "bernoulli": 143, "categor": 144, "gumbel": 145, "kei": 146, "normal": [147, 268], "randint": 148, "seed": 149, "truncated_norm": 151, "uniform": [152, 269], "repeat": 154, "save": [158, 308], "save_gguf": 159, "save_safetensor": 160, "savez": 161, "savez_compress": 162, "set_default_devic": 163, "set_default_stream": 164, "sigmoid": 165, "sign": 166, "simplifi": [], "sinh": 168, "softmax": 169, "sort": 170, "squeez": 174, "stack": 175, "stop_gradi": 176, "subtract": 177, "swapax": 179, "take": 180, "take_along_axi": 181, "tan": 182, "tanh": 183, "tensordot": 184, "tri": 186, "tril": 187, "triu": 188, "value_and_grad": [189, 196], "vjp": 191, "vmap": 192, "where": 193, "zero": 194, "zeros_lik": 195, "nn": [196, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292], "optim": [197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 299], "adadelta": 197, "adagrad": 199, "adam": 200, "adamw": 201, "adamax": 202, "lion": 203, "optimizerst": 205, "rmsprop": 206, "sgd": 207, "util": [208, 209, 210, 302], "tree_flatten": 208, "tree_map": 209, "tree_unflatten": 210, "data": 212, "type": 212, "support": 212, "algebra": 215, "neural": 216, "network": 216, "quick": [216, 307], "start": [216, 307], "The": 216, "modul": [216, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 297], "class": 216, "paramet": [216, 241], "updat": [216, 247, 
304], "inspect": 216, "valu": 216, "alibi": 217, "batchnorm": 218, "dropout": 221, "dropout2d": 222, "dropout3d": 223, "embed": 224, "gelu": [225, 270], "groupnorm": 226, "instancenorm": 227, "layernorm": 228, "mish": [230, 286], "appli": 231, "apply_to_modul": 232, "children": 233, "filter_and_map": 235, "freez": 236, "leaf_modul": 237, "load_weight": 238, "named_modul": 240, "save_weight": 242, "train": [243, 245], "trainable_paramet": 244, "unfreez": 246, "update_modul": 248, "multiheadattent": 249, "prelu": [250, 287], "quantizedlinear": 251, "rmsnorm": 252, "relu": [253, 288], "rope": 254, "selu": [255, 289], "sequenti": 256, "silu": [257, 290], "sinusoidalpositionalencod": 258, "step": [260, 292], "gelu_approx": 271, "gelu_fast_approx": 272, "loss": [273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 296], "binary_cross_entropi": 273, "cosine_similarity_loss": 274, "cross_entropi": 275, "hinge_loss": 277, "huber_loss": 278, "kl_div_loss": 279, "l1_loss": 280, "log_cosh_loss": 281, "mse_loss": 282, "nll_loss": 283, "smooth_l1_loss": 284, "triplet_loss": 285, "function": [293, 296, 303, 307], "tree": 302, "index": 304, "differ": 304, "numpi": [304, 306], "In": 304, "place": 304, "lazi": 305, "evalu": 305, "why": 305, "comput": 305, "graph": [305, 307], "onli": 305, "what": 305, "you": 305, "when": 305, "convers": 306, "other": 306, "framework": 306, "pytorch": 306, "jax": 306, "tensorflow": 306, "guid": 307, "basic": 307, "serial": 308, "format": 308, "unifi": 309, "memori": 309, "A": 309, "simpl": 309, "specifi": 310, "isinf": 105, "isnan": 106, "isneginf": 107, "isposinf": 108, "automat": 303, "differenti": 303, "vector": 303, "diag": 73, "diagon": 74, "qr": 113, "adafactor": 198, "softshrink": [259, 291], "init": [262, 263, 264, 265, 266, 267, 268, 269], "constant": 262, "glorot_norm": 263, "glorot_uniform": 264, "he_norm": 265, "he_uniform": 266, "gaussian_nll_loss": 276, "initi": 294}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}})
\ No newline at end of file
diff --git a/docs/build/html/usage/function_transforms.html b/docs/build/html/usage/function_transforms.html
index a57d98443..9071c0bd5 100644
--- a/docs/build/html/usage/function_transforms.html
+++ b/docs/build/html/usage/function_transforms.html
@@ -9,7 +9,7 @@
- Function Transforms — MLX 0.0.9 documentation
+ Function Transforms — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils
+Tree Utils
mlx.utils.tree_flatten
mlx.utils.tree_unflatten
mlx.utils.tree_map
@@ -674,7 +691,7 @@ describe below.
Transforming Compute Graphs
Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like grad() and
-vmap() and graph optimizations like simplify().
+vmap() and graph optimizations.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.
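As a rough sketch of what this paragraph describes (a hypothetical example against the public mlx.core API, not code taken from this diff): grad() transforms a function whose compute graph has merely been recorded, and no numeric work happens until eval() forces it.

import mlx.core as mx

def f(x):
    return mx.sum(mx.sin(x))

# grad() operates on the recorded compute graph of f;
# nothing is evaluated at this point.
dfdx = mx.grad(f)

x = mx.array([0.0, 1.0, 2.0])
g = dfdx(x)   # builds the gradient graph lazily
mx.eval(g)    # forces evaluation; g now holds cos(x)
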
diff --git a/docs/build/html/usage/numpy.html b/docs/build/html/usage/numpy.html
index 52bf06e40..6eee51a40 100644
--- a/docs/build/html/usage/numpy.html
+++ b/docs/build/html/usage/numpy.html
@@ -9,7 +9,7 @@
- Conversion to NumPy and Other Frameworks — MLX 0.0.9 documentation
+ Conversion to NumPy and Other Frameworks — MLX 0.1.0 documentation
@@ -134,8 +134,8 @@
-
-
+
+
@@ -241,6 +241,8 @@
mlx.core.cos
mlx.core.cosh
mlx.core.dequantize
+mlx.core.diag
+mlx.core.diagonal
mlx.core.divide
mlx.core.divmod
mlx.core.equal
@@ -351,7 +353,6 @@
mlx.core.jvp
mlx.core.vjp
mlx.core.vmap
-mlx.core.simplify
FFT
Linear Algebra
Neural Networks
@@ -433,6 +436,7 @@
mlx.nn.prelu
mlx.nn.relu
mlx.nn.selu
+mlx.nn.softshrink
mlx.nn.silu
mlx.nn.step
@@ -441,6 +445,7 @@
mlx.nn.losses.binary_cross_entropy
mlx.nn.losses.cosine_similarity_loss
mlx.nn.losses.cross_entropy
+mlx.nn.losses.gaussian_nll_loss
mlx.nn.losses.hinge_loss
mlx.nn.losses.huber_loss
mlx.nn.losses.kl_div_loss
@@ -452,14 +457,26 @@
mlx.nn.losses.triplet_loss
+Initializers
-Optimizers
+
+Optimizers
-Tree Utils