Awni Hannun
e88f2d4a8e
fix cross entropy axis param ( #2641 )
...
* fix cross entropy axis param
* faster grad clipping
2025-10-01 16:49:55 -07:00
Gökdeniz Gülmez
db5443e831
Adding Relu2 ( #2582 )
...
* in. com.
* upd. ackn.
* update __init__
* nits
* nits + format
* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer
* same with _make_activation_module
* Update python/mlx/nn/layers/activations.py
upd
Co-authored-by: Awni Hannun <awni.hannun@gmail.com >
* update funct.rst
* upd. layers.rst
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com >
2025-09-10 07:24:30 -07:00
XXXXRT666
8f163a367d
typing: add type hints to mlx.core.array, linalg, distributed, and random ( #2565 )
...
* Add type annotations to mlx methods
* Missing list_or_scalar
2025-09-04 09:08:11 -07:00
Manuel Villanueva
89a3df9014
Fixed several type annotations in the MLX stubs which degraded to Unknown/Any ( #2560 )
...
* Added scalar to stubs to fix Unkown Type Hint
### Proposed changes
Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.
This PR updates the stub generation patterns to:
• Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
• Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
• Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.
With these changes, functions like mlx.add now display rich type signatures such as:
```
def add(
a: scalar | array,
b: scalar | array,
stream: Stream | Device | None = None
) -> array
```
instead of degrading to Any.
### Checklist
• I have read the CONTRIBUTING document
• I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
• I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
• I have updated the necessary documentation (if needed)
* add bool to patterns
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2025-09-03 12:52:08 -07:00
Artur Antonov
c5460762e7
Fix AdamW weight_decay default value in docstring ( #2557 )
2025-08-31 21:29:30 -07:00
Awni Hannun
111f1e71af
Faster contiguous gather for indices in the first axis ( #2552 )
...
* faster contiguous gather for indices in the first axis
* work per thread > 1
* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00
Awni Hannun
70560b6bd5
Add mode parameter for quantization ( #2499 )
...
* add mode parameter for quantization
* mxfp4 quantize/dequantize + start of optional biases
* mxfp4 works
* speedup
* cpu mxfp4
* fix
* fix test tol
* fix
* refactor
* add quant mode enum
2025-08-28 06:45:26 -07:00
Awni Hannun
3dcb286baf
Remove stream from average grads so it uses default ( #2532 )
...
* Remove stream from average grads so it uses default
* comment
2025-08-25 15:56:29 -07:00
Awni Hannun
068a4612e9
nccl default for backend=any ( #2528 )
...
* nccl default for backend=any
* check num gpus + ensure row contiguous for all reduce
* comment
2025-08-22 12:24:27 -07:00
Awni Hannun
f93f87c802
nccl dep + default for cuda ( #2526 )
2025-08-21 17:57:49 -07:00
Anastasiia Filippova
9392fc3f88
NCCL backend ( #2476 )
2025-08-21 11:56:15 -07:00
Luca Vivona
728d4db582
Support destination arg in tree flatten/unflatten ( #2450 )
2025-08-06 15:34:59 -07:00
Awni Hannun
b405591249
fix circular reference ( #2443 )
2025-07-30 09:37:44 -07:00
Skonor
7d9d6ef456
docs: fix adam and adamw eps placement ( #2416 )
...
Co-authored-by: Mikhail Gorbunov <m_gorbunov@apple.com >
2025-07-24 16:40:45 -07:00
Gökdeniz Gülmez
deee214a95
Adding support for the Muon Optimizer ( #1914 )
...
* initial commit with workong optmimizer
* update ACKNOWLEDGMENTS.md
* nits and adding it to test
* nits
* G.astype(mx.bfloat16) to G.astype(G.dtype)
* G.ndim >= 2 to assert G.ndim == 2
* remove coments
* replace with mx.addmm
* remove comments
* format
* nits
* match muon
* fix addmm
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2025-07-18 12:25:28 -07:00
Angelos Katharopoulos
0eb035b4b1
Fix type promotion in Adam with bias correction ( #2350 )
2025-07-10 11:14:42 -07:00
Cheng
19facd4b20
Build with all cpu cores by default ( #2336 )
2025-07-07 06:06:45 -07:00
Awni Hannun
cfb6a244ea
allow parameters to be deleted ( #2325 )
2025-07-01 21:27:23 -07:00
Awni Hannun
33bf1a244b
Fix module update in strict mode ( #2321 )
...
* fix module update in strict mode
* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
5adf185f86
Fix update_modules() when providing a subset ( #2308 )
2025-06-20 17:19:46 -07:00
Christopher Fleetwood
004c1d8ef2
Report number of missing parameters ( #2264 )
...
* chore: inform
* chore: format
---------
Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com >
2025-06-10 06:37:50 -07:00
Awni Hannun
c763fe1be0
default strict mode for module update and update_modules ( #2239 )
2025-06-05 15:27:02 -07:00
Angelos Katharopoulos
0359bf02c9
Nearest upsample ( #2202 )
2025-05-19 11:23:38 -07:00
Angelos Katharopoulos
a3a632d567
Fix the launcher when ran locally ( #2147 )
2025-05-01 12:56:09 -07:00
Awni Hannun
aa5d84f102
Allow quant layer to be unfrozen ( #2142 )
2025-04-30 09:08:29 -07:00
Param Thakkar
600e87e03c
Added output_padding parameters in conv_transpose ( #2092 )
2025-04-23 09:26:33 -07:00
Awni Hannun
68d1b3256b
nit: fix exception handling ( #2066 )
2025-04-11 14:12:08 -07:00
Yi Wang
a8931306e1
Remove unused variable in CMakeBuild ( #2011 )
...
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Chunyang Wen
022eabb734
Remove unused import ( #1987 )
2025-03-24 20:19:32 -07:00
Angelos Katharopoulos
4eef8102c9
Distributed layers ( #1270 )
2025-03-21 13:52:17 -07:00
jiyzhang
65a38c452b
update the formula of smooth_l1_loss ( #1986 )
2025-03-21 06:25:23 -07:00
jiyzhang
95e335db7b
Update smooth_l1_loss in losses.py ( #1974 )
...
According the definition of smooth_l1_loss, the line
diff = predictions - targets
Should be updated to
diff = mx.abs(predictions - targets)
After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
Chunyang Wen
3779150750
refactor: all use schedule ( #1973 )
2025-03-19 11:24:04 -07:00
Chunyang Wen
45ad06aac8
Fix typo; Fix lint warning when reuse the same name ( #1968 )
...
* Fix typo; Fix lint warning when reuse the same name
* Add missing period
2025-03-18 07:12:24 -07:00
Chunyang Wen
cffceda6ee
Add type hint for _extra_repr ( #1948 )
2025-03-10 06:05:36 -07:00
Chunyang Wen
d14c9fe7ea
Add file info when raising errors in save ( #1943 )
2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822
Fix obsured warning ( #1944 )
2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330
Fix unreachable warning ( #1939 )
...
* Fix unreachable warning
* Update error message
2025-03-07 17:23:04 -08:00
Chunyang Wen
a198b2787e
Remove unused modules ( #1936 )
2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59
Add doc string for path ( #1937 )
2025-03-06 14:20:09 -08:00
Angelos Katharopoulos
0792ff02ff
Only fail when 10 consecutive socket errors occur ( #1928 )
2025-03-05 13:16:19 -08:00
Angelos Katharopoulos
9680f72cca
Add a multi optimizer ( #1916 )
2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3
Allow debugging in distributed mode ( #1920 )
2025-03-04 13:01:10 -08:00
Angelos Katharopoulos
5d68082881
Ring docs ( #1829 )
2025-02-28 11:34:21 -08:00
Angelos Katharopoulos
607181644f
Add mlx.distributed_config script ( #1902 )
2025-02-28 11:16:39 -08:00
Franck Verrot
7df3f792a2
Ensure Conv2D and Conv3D's kernel sizes aren't trimmed ( #1852 )
...
Before the change, this snippet:
```
print(nn.Conv1d(1, 32, 3, padding=1))
print(nn.Conv2d(1, 32, (3, 3), padding=1))
print(nn.Conv3d(1, 32, (3, 3, 3), padding=1))
```
would output:
```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3,), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
After the change, the output will be:
```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
2025-02-10 06:27:01 -08:00
Awni Hannun
83a0340fa7
allow command ( #1836 )
2025-02-06 10:32:24 -08:00
Awni Hannun
ca305afdbe
loading empty list is ok when strict = false ( #1834 )
2025-02-05 16:19:27 -08:00
Awni Hannun
ec7c7def40
no line buffer for mpi jobs ( #1825 )
2025-02-03 12:02:15 -08:00
Angelos Katharopoulos
ded914f442
Small distributed launch helper ( #1810 )
2025-01-29 17:55:04 -08:00