Commit Graph

209 Commits

Gökdeniz Gülmez
deee214a95
Adding support for the Muon Optimizer (#1914)
* initial commit with working optimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove comments

* replace with mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-18 12:25:28 -07:00
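For context, below is a minimal sketch of the Newton-Schulz orthogonalization step that Muon-style optimizers apply to 2-D gradient matrices, written against MLX and using mx.addmm as the commit bullets suggest. The coefficients, step count, and helper name are the commonly published ones and are illustrative, not necessarily what #1914 merged.

```
import mlx.core as mx

def newton_schulz(G, steps=5, eps=1e-7):
    # Muon orthogonalizes 2-D gradient matrices only (hence the assert in the PR).
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (mx.linalg.norm(G) + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        # B = b*A + c*(A @ A), fused with mx.addmm as the commit notes mention
        B = mx.addmm(b * A, c * A, A)
        # X = a*X + B @ X
        X = mx.addmm(a * X, B, X)
    if transposed:
        X = X.T
    # Keep the input dtype, mirroring the G.astype(G.dtype) change listed above.
    return X.astype(G.dtype)
```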
Angelos Katharopoulos
0eb035b4b1
Fix type promotion in Adam with bias correction (#2350) 2025-07-10 11:14:42 -07:00
Cheng
19facd4b20
Build with all cpu cores by default (#2336) 2025-07-07 06:06:45 -07:00
Awni Hannun
cfb6a244ea
allow parameters to be deleted (#2325) 2025-07-01 21:27:23 -07:00
Awni Hannun
33bf1a244b
Fix module update in strict mode (#2321)
* fix module update in strict mode

* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
5adf185f86
Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Christopher Fleetwood
004c1d8ef2
Report number of missing parameters (#2264)
* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
2025-06-10 06:37:50 -07:00
Awni Hannun
c763fe1be0
default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Angelos Katharopoulos
0359bf02c9
Nearest upsample (#2202) 2025-05-19 11:23:38 -07:00
Angelos Katharopoulos
a3a632d567
Fix the launcher when run locally (#2147) 2025-05-01 12:56:09 -07:00
Awni Hannun
aa5d84f102
Allow quant layer to be unfrozen (#2142) 2025-04-30 09:08:29 -07:00
Param Thakkar
600e87e03c
Added output_padding parameters in conv_transpose (#2092) 2025-04-23 09:26:33 -07:00
Awni Hannun
68d1b3256b
nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Yi Wang
a8931306e1
Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Chunyang Wen
022eabb734
Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Angelos Katharopoulos
4eef8102c9
Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
jiyzhang
65a38c452b
update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
jiyzhang
95e335db7b
Update smooth_l1_loss in losses.py (#1974)
According to the definition of smooth_l1_loss, the line

diff = predictions - targets

should be updated to

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch's smooth_l1_loss.
2025-03-19 20:19:02 -07:00
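For clarity, a minimal reference sketch of the corrected definition (reduction handling omitted); it mirrors the behavior of nn.losses.smooth_l1_loss but is not the library code itself.

```
import mlx.core as mx

def smooth_l1(predictions, targets, beta=1.0):
    # Both the branch condition and the linear branch use the absolute difference.
    diff = mx.abs(predictions - targets)
    return mx.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
```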
Chunyang Wen
3779150750
refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
Chunyang Wen
45ad06aac8
Fix typo; Fix lint warning when reusing the same name (#1968)
* Fix typo; Fix lint warning when reusing the same name

* Add missing period
2025-03-18 07:12:24 -07:00
Chunyang Wen
cffceda6ee
Add type hint for _extra_repr (#1948) 2025-03-10 06:05:36 -07:00
Chunyang Wen
d14c9fe7ea
Add file info when raising errors in save (#1943) 2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822
Fix obscured warning (#1944) 2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330
Fix unreachable warning (#1939)
* Fix unreachable warning

* Update error message
2025-03-07 17:23:04 -08:00
Chunyang Wen
a198b2787e
Remove unused modules (#1936) 2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59
Add doc string for path (#1937) 2025-03-06 14:20:09 -08:00
Angelos Katharopoulos
0792ff02ff
Only fail when 10 consecutive socket errors occur (#1928) 2025-03-05 13:16:19 -08:00
Angelos Katharopoulos
9680f72cca
Add a multi optimizer (#1916) 2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3
Allow debugging in distributed mode (#1920) 2025-03-04 13:01:10 -08:00
Angelos Katharopoulos
5d68082881
Ring docs (#1829) 2025-02-28 11:34:21 -08:00
Angelos Katharopoulos
607181644f
Add mlx.distributed_config script (#1902) 2025-02-28 11:16:39 -08:00
Franck Verrot
7df3f792a2
Ensure Conv2D and Conv3D's kernel sizes aren't trimmed (#1852)
Before the change, this snippet:

```
print(nn.Conv1d(1, 32, 3, padding=1))
print(nn.Conv2d(1, 32, (3, 3), padding=1))
print(nn.Conv3d(1, 32, (3, 3, 3), padding=1))
```

would output:

```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3,), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```

After the change, the output will be:

```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
2025-02-10 06:27:01 -08:00
Awni Hannun
83a0340fa7
allow command (#1836) 2025-02-06 10:32:24 -08:00
Awni Hannun
ca305afdbe
loading empty list is ok when strict = false (#1834) 2025-02-05 16:19:27 -08:00
Awni Hannun
ec7c7def40
no line buffer for mpi jobs (#1825) 2025-02-03 12:02:15 -08:00
Angelos Katharopoulos
ded914f442
Small distributed launch helper (#1810) 2025-01-29 17:55:04 -08:00
Awni Hannun
1017ac4a9e
add dilation for conv 3d layers + test for 3d conv w/ dilation (#1802) 2025-01-28 06:17:07 -08:00
Awni Hannun
2235dee906
catch stream errors earlier to avoid aborts (#1801) 2025-01-27 14:05:43 -08:00
Nripesh Niketan
5cc5201914
feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 07:29:20 -08:00
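A hedged usage sketch of the initializer added here, following the call pattern of the other nn.init factories (a factory returning a callable that fills an array of the requested shape); check the nn.init.orthogonal signature in your MLX version.

```
import mlx.core as mx
import mlx.nn as nn

# nn.init.orthogonal() returns a callable that maps a template array to a
# (semi-)orthogonal matrix of the same shape.
init_fn = nn.init.orthogonal()
w = init_fn(mx.zeros((256, 128)))
print(w.shape)  # (256, 128)
```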
Awni Hannun
657f466402
use sdpa and exportable functions in transformer multi head attention (#1760) 2025-01-09 13:11:55 -08:00
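As a hedged illustration of the pattern this commit describes, the attention core can be expressed with mx.fast.scaled_dot_product_attention instead of explicit softmax matmuls; the shapes below are assumptions chosen to show the call, not the MultiHeadAttention internals.

```
import mlx.core as mx

B, H, T, D = 1, 8, 128, 64
q = mx.random.normal((B, H, T, D))
k = mx.random.normal((B, H, T, D))
v = mx.random.normal((B, H, T, D))

# Fused attention kernel; an additive mask can be passed via the mask argument.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)
print(out.shape)  # (1, 8, 128, 64)
```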
Awni Hannun
c9d30aa6ac
MLX in C++ example (#1736)
* MLX in C++ example

* nits

* fix docs
2025-01-02 19:09:04 -08:00
Awni Hannun
c3628eea49
Add mx.finfo and use it when making causal mask (#1726)
* finfo

* fixes

* docs
2024-12-19 14:52:41 -08:00
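A small sketch of the mask-building pattern the commit describes: blocked positions get the most negative finite value for the target dtype from mx.finfo rather than a hard-coded constant. The helper name is illustrative.

```
import mlx.core as mx

def causal_mask(T, dtype=mx.float32):
    # Upper triangle (future positions) is filled with the most negative
    # finite value representable in `dtype`, taken from mx.finfo.
    neg_inf = mx.finfo(dtype).min
    return mx.triu(mx.full((T, T), neg_inf, dtype=dtype), k=1)

print(causal_mask(4))
```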
Tomohiro Oga
a6b426422e
add cubic to type hinting for upsample (#1709) 2024-12-17 07:30:23 -08:00
Awni Hannun
29a620cab2
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding

* fix inadvertent cast

* add tol
2024-12-09 18:57:38 -08:00
mt_caret
fd3377dd1f
Support bias correction in Adam and AdamW optimizers (#1640) 2024-12-06 12:13:34 -08:00
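Bias correction rescales the running moments by 1 / (1 - beta^t) before the update. A hedged usage sketch, assuming the keyword is named bias_correction as the PR title suggests; older versions default to the uncorrected update.

```
import mlx.optimizers as optim

# `bias_correction` is assumed from the PR title; check the Adam signature
# in your MLX version before relying on it.
opt = optim.Adam(learning_rate=1e-3, betas=[0.9, 0.999], bias_correction=True)
```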
Alex Barron
1445dcaa60
let class predicate specify quantization parameters (#1638) 2024-12-02 14:09:28 -08:00
Awni Hannun
aa86876813
fix transformer decoder post norm LN (#1637) 2024-12-02 07:02:17 -08:00
Awni Hannun
7cbb4aef17
Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Angelos Katharopoulos
d8c824c594
Formatting fixes (#1606) 2024-11-20 15:30:36 -08:00
Saanidhya
cb431dfc9f
Adds 3D pooling (#1526) 2024-11-19 16:45:24 -08:00
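A hedged usage sketch of the new 3-D pooling layer; it assumes the channels-last (N, D, H, W, C) layout used by the other MLX pooling layers.

```
import mlx.core as mx
import mlx.nn as nn

# 2x2x2 max pooling over a (N, D, H, W, C) input halves each spatial dimension.
x = mx.random.uniform(shape=(1, 16, 16, 16, 8))
pool = nn.MaxPool3d(kernel_size=2, stride=2)
print(pool(x).shape)  # expected (1, 8, 8, 8, 8)
```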