Awni Hannun
28f39e9038
Log for complex numbers in Metal ( #2025 )
...
* Log for complex numbers in Metal
* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
05d7118561
causal vector sdpa ( #2018 )
...
* causal vector sdpa
* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66
enable complex gemm ( #2017 )
2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291
iinfo and scalar overflow detection ( #2009 )
2025-03-27 19:54:56 -07:00
Yi Wang
a8931306e1
Remove unused variable in CMakeBuild ( #2011 )
...
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Chunyang Wen
022eabb734
Remove unused import ( #1987 )
2025-03-24 20:19:32 -07:00
Awni Hannun
a84cc0123f
promote mask when needed ( #1998 )
2025-03-23 19:58:28 -07:00
Angelos Katharopoulos
4eef8102c9
Distributed layers ( #1270 )
2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b
Add a ring all gather ( #1985 )
2025-03-21 13:36:51 -07:00
Awni Hannun
2a980a76ce
Add stats and limit to common allocator and enable tests ( #1988 )
...
* add stats to common allocator and enable tests
* linux memory and default
* fix
2025-03-21 12:28:36 -07:00
Awni Hannun
4e1994e9d7
move memory APIs into top level mlx.core ( #1982 )
2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b
update the formula of smooth_l1_loss ( #1986 )
2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd
fix malloc or wait deadlock ( #1976 )
2025-03-20 16:48:43 -07:00
Awni Hannun
005e7efa64
fix mask in sdpa ( #1980 )
...
* fix mask in sdpa
* fix attention mask
* Re-enable routing for array mask
---------
Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84
Update attention tests to show diff, disable array masks ( #1978 )
2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650
Support fused masking in Attention ( #1924 )
...
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c164fca8c
Fix multistream GPU deadlock ( #1969 )
...
* fix multistream GPU deadlock
* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b
Update smooth_l1_loss in losses.py ( #1974 )
...
According the definition of smooth_l1_loss, the line
diff = predictions - targets
Should be updated to
diff = mx.abs(predictions - targets)
After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
Chunyang Wen
3779150750
refactor: all use schedule ( #1973 )
2025-03-19 11:24:04 -07:00
Chunyang Wen
45ad06aac8
Fix typo; Fix lint warning when reuse the same name ( #1968 )
...
* Fix typo; Fix lint warning when reuse the same name
* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329
Use same accumulation precision in gemv as gemm ( #1962 )
...
* use same accumulation precision in gemv as gemm
* faster
* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240
fix grad with inplace updates ( #1961 )
2025-03-13 19:13:09 -07:00
Awni Hannun
32da94507a
fix vmap for flatten ( #1955 )
2025-03-11 10:42:22 -07:00
Awni Hannun
3c3e558c60
Support transposed head/seq for kv ( #1950 )
...
* support transposed head/seq for kv
* fix flaky test
* nit
2025-03-10 10:53:45 -07:00
Chunyang Wen
cffceda6ee
Add type hint for _extra_repr ( #1948 )
2025-03-10 06:05:36 -07:00
Chunyang Wen
d14c9fe7ea
Add file info when raising errors in save ( #1943 )
2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822
Fix obsured warning ( #1944 )
2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330
Fix unreachable warning ( #1939 )
...
* Fix unreachable warning
* Update error message
2025-03-07 17:23:04 -08:00
Chunyang Wen
a198b2787e
Remove unused modules ( #1936 )
2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59
Add doc string for path ( #1937 )
2025-03-06 14:20:09 -08:00
David Wisdom
392b3060b0
Fix typo in randint docstring ( #1932 )
...
This commit fixes a typo in the docstring for mlx.core.random.randint() by changing "roadcastable" to "broadcastable".
2025-03-05 21:48:00 -08:00
Angelos Katharopoulos
0792ff02ff
Only fail when 10 consecutive socket errors occur ( #1928 )
2025-03-05 13:16:19 -08:00
Abe Leininger
3835a428c5
Adds nuclear norm support ( #1894 )
...
* adjust norm unit test tolerance
2025-03-04 13:26:02 -08:00
Angelos Katharopoulos
9680f72cca
Add a multi optimizer ( #1916 )
2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3
Allow debugging in distributed mode ( #1920 )
2025-03-04 13:01:10 -08:00
Awni Hannun
e613d0eaf0
SDPA support for small batch (over sequence) queries ( #1922 )
...
* batch query sdpa
* batch sdpa for query
2025-03-04 10:59:04 -08:00
Awni Hannun
6bcd6bcf70
fix donation in scan ( #1917 )
2025-03-03 11:30:59 -08:00
Awni Hannun
4e7cd31d12
Fix slice data size ( #1913 )
...
* fix slice data size
* add test
2025-03-02 21:50:42 -08:00
Angelos Katharopoulos
5e6c130d93
RMS norm without scaling ( #1915 )
2025-02-28 20:26:57 -08:00
Angelos Katharopoulos
5d68082881
Ring docs ( #1829 )
2025-02-28 11:34:21 -08:00
Angelos Katharopoulos
607181644f
Add mlx.distributed_config script ( #1902 )
2025-02-28 11:16:39 -08:00
Angelos Katharopoulos
6bf00ef631
Fix ring of 2 and allow scalars in API ( #1906 )
2025-02-25 17:03:01 -08:00
Awni Hannun
7d042f17fe
Double for lapack ( #1904 )
...
* double for lapack ops
* add double support for lapack ops
2025-02-25 11:39:36 -08:00
Awni Hannun
28b8079e30
fix double type promotion ( #1901 )
2025-02-25 06:00:53 -08:00
Awni Hannun
7face5d9fd
fix cpu compile ( #1897 )
2025-02-24 14:10:30 -08:00
Awni Hannun
2d0f384b6f
fix simd erf_inv ( #1896 )
2025-02-24 13:57:47 -08:00
Angelos Katharopoulos
10b271d963
Ring update ( #1885 )
2025-02-20 14:32:31 -08:00
Awni Hannun
bbda0fdbdb
Allow non-square lu ( #1889 )
2025-02-20 08:13:23 -08:00
Awni Hannun
c707b2b0a6
Limit compile buffers ( #1887 )
...
* limit compile buffers
* maybe not flaky test
2025-02-19 20:28:13 -08:00
Angelos Katharopoulos
78ba24c37d
Raise an exception in the rope op if input is integer ( #1884 )
2025-02-19 14:43:39 -08:00
Angelos Katharopoulos
1a2cb72030
Ensure linspace always contains start and stop ( #1883 )
2025-02-19 13:53:20 -08:00
Abe Leininger
344a29506e
Enforce triangular matrix form in tri_inv
( #1876 )
...
* fix tri_inv bug
* Revert "fix tri_inv bug"
This reverts commit b74b290201
.
* Make sure that tri_inv returns a triangular matrix
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-02-19 12:42:33 -08:00
Angelos Katharopoulos
71de73a668
Fix convs by reverting #1803 ( #1882 )
2025-02-18 14:36:34 -08:00
Alex Barron
4c1dfa58b7
xor op on arrays ( #1875 )
2025-02-17 00:24:53 -08:00
Jagrit Digani
2dc307f2e6
Winograd Update for Small batches ( #1803 )
...
* Build in padding to Winograd kernels
* Add new fused Winograd kernel
* Enable weight flipping in Winograd kernels
2025-02-14 13:08:13 -08:00
Alex Barron
7f2d1024f3
add f8_e4m3 loading ( #1859 )
2025-02-13 17:10:03 -08:00
Awni Hannun
428f589364
Revert "More buffer donation in some cases ( #1858 )" ( #1863 )
...
This reverts commit d274ae77f2
.
2025-02-13 14:21:44 -08:00
Alex Barron
5cd97f7ffe
Bitwise Inverse ( #1862 )
...
* add bitwise inverse
* add vmap + fix nojit
* inverse -> invert
* add to compile + remove unused
2025-02-13 08:44:14 -08:00
Awni Hannun
d274ae77f2
More buffer donation in some cases ( #1858 )
...
* more donation
* fix
* add test
2025-02-12 19:41:37 -08:00
Alex Barron
55c5ac7820
fix int64 bug ( #1860 )
2025-02-12 19:23:46 -08:00
Angelos Katharopoulos
0145911bea
Fixes output donation for IO ops on the GPU ( #1857 )
2025-02-12 10:52:30 -08:00
Awni Hannun
0a5215693e
Fix grad copies ( #1854 )
...
* fix grad with copies
* add test
* add test
2025-02-11 15:26:42 -08:00
Awni Hannun
2a45056ba8
Cycle leak break ( #1856 )
...
* detect and break leaks in custom function
* detect and break leaks in custom function
2025-02-11 14:45:02 -08:00
Abe Leininger
a5ededf1c3
CPU LU factorization and linear solvers ( #1451 )
...
* linalg solve backend
* nits
* more nits + fix
* luf primitive and lu, solve, and solve_triangular backends
* changes / nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2025-02-10 12:32:24 -08:00
Franck Verrot
7df3f792a2
Ensure Conv2D and Conv3D's kernel sizes aren't trimmed ( #1852 )
...
Before the change, this snippet:
```
print(nn.Conv1d(1, 32, 3, padding=1))
print(nn.Conv2d(1, 32, (3, 3), padding=1))
print(nn.Conv3d(1, 32, (3, 3, 3), padding=1))
```
would output:
```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3,), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
After the change, the output will be:
```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
2025-02-10 06:27:01 -08:00
Angelos Katharopoulos
9eb7d7362f
Fix Split::vmap ( #1845 )
2025-02-08 09:22:13 -08:00
Awni Hannun
1c0c118f7c
Fp64 on the CPU ( #1843 )
...
* add fp64 data type
* clean build
* update docs
* fix bug
2025-02-07 15:52:22 -08:00
Awni Hannun
83a0340fa7
allow command ( #1836 )
2025-02-06 10:32:24 -08:00
Awni Hannun
af1b725fda
Fix a couple of slicing bugs ( #1827 )
...
* fix a few bugs
* fix conv grad
* speedup test
* comment
2025-02-05 19:50:08 -08:00
Awni Hannun
9174606d4c
fix sort ( #1835 )
2025-02-05 17:16:27 -08:00
Awni Hannun
ca305afdbe
loading empty list is ok when strict = false ( #1834 )
2025-02-05 16:19:27 -08:00
Awni Hannun
ec7c7def40
no line buffer for mpi jobs ( #1825 )
2025-02-03 12:02:15 -08:00
Angelos Katharopoulos
f5cc1eea72
Allow different value dimensions in sdpa_vector ( #1811 )
2025-01-31 20:58:59 -08:00
Awni Hannun
b7c9f1d38f
scatter axis + gather axis primitives ( #1813 )
...
* scatter axis + gather axis primitives
* add transforms
* comment
2025-01-31 20:48:08 -08:00
Angelos Katharopoulos
ded914f442
Small distributed launch helper ( #1810 )
2025-01-29 17:55:04 -08:00
Awni Hannun
4758c8baa1
Start to cleanup/unify accelerate and common back-ends (Part 1/N) ( #1777 )
...
* start to cleanup/unify accelerate and common back-ends
* more progress
* simplify
* add half type and allow infs in simd exp
* unify softmax + quantized, more dispatches to simd quantized mm
* add sin/cos, use simd in vector-scalar ops
* faster CPU vectorize quant
* faster erf/erfinv
2025-01-29 14:34:49 -08:00
Awni Hannun
1017ac4a9e
add dilation for conv 3d layers + test for 3d conv w/ dilation ( #1802 )
2025-01-28 06:17:07 -08:00
Angelos Katharopoulos
ccb61d7aae
Ring distributed backend ( #1784 )
2025-01-27 22:15:01 -08:00
Awni Hannun
2235dee906
catch stream errors earlier to avoid aborts ( #1801 )
2025-01-27 14:05:43 -08:00
Awni Hannun
28091aa1ff
allow build python lib without specifying path ( #1799 )
2025-01-27 11:22:35 -08:00
Awni Hannun
121d9a0702
Fix rope fallback to not upcast ( #1797 )
...
* fix rope fallback to not upcast
* Update mlx/fast.cpp
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-01-26 19:07:21 -08:00
Nick
0cea88bcc5
Use @ matrix multiplication syntax to document matrix-matrix multiplication ( #1793 )
...
Co-authored-by: Nick Thompson <nicholas_a_thompson@apple.com>
2025-01-25 16:02:36 -08:00
Angelos Katharopoulos
72146fc4cd
Einsum ellipsis ( #1788 )
2025-01-25 01:28:03 -08:00
Awni Hannun
e6a7ab9675
non square qr ( #1783 )
2025-01-21 14:07:47 -08:00
Awni Hannun
90532b1f37
recompile when shapeless is different ( #1776 )
2025-01-20 21:07:10 -08:00
Awni Hannun
0c259961ac
matmul jvps ( #1772 )
2025-01-17 10:36:26 -08:00
Awni Hannun
33421c1dd3
Limit grad recursion depth by not recursing through non-grad inputs ( #1764 )
...
* limit grad recursion depth
* add grad of module test
2025-01-14 14:33:18 -08:00
Nripesh Niketan
5cc5201914
feat: Add orthogonal initializer and corresponding tests ( #1651 )
...
* feat: Add orthogonal initializer and corresponding tests
* lint
* Add acknowledgements
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 07:29:20 -08:00
wrmsr
a4a2764a52
Fix broadcast_arrays python sig ( #1763 )
2025-01-10 12:33:26 -08:00
Awni Hannun
657f466402
use sdpa and exportable functions in transformer multi head attention ( #1760 )
2025-01-09 13:11:55 -08:00
Alex Barron
c7b0300af5
Fix batched qmv bug ( #1758 )
2025-01-09 11:45:57 -08:00
Awni Hannun
1ccaf80575
Dynamic broadcasting for shapeless compile/export ( #1722 )
...
* working towards dynamic broadcast
* shapeless broadcast
* fix build + nits
* use broadcast arrays in quantize matmul
* some cleanup / consistency
* mend
* some comments
* add vjp, jvp for broadcast axes
2025-01-09 11:04:24 -08:00
Awni Hannun
d1766f2c70
Add boolean mask support in vector SDPA ( #1757 )
2025-01-07 20:24:53 -08:00
Awni Hannun
516ded618b
Dynamic slicing ( #1741 )
...
* dynamic slice and slice update
* python bindings + tests + fix set item
* fix compile issue
* comment
* fix jit
2025-01-07 14:02:16 -08:00
Awni Hannun
d5ec172c95
Allow boolean mask in sdpa ( #1753 )
...
* allow boolean mask in sdpa
* more permissive donation in ternary
2025-01-06 16:57:07 -08:00
Angelos Katharopoulos
25b3a3e541
Optionally specify names for arrays when exporting ( #1749 )
2025-01-06 13:07:46 -08:00
Awni Hannun
058d6ce683
mpi send use input as output ( #1750 )
...
* mpi send use input as output
* move earlier
2025-01-06 06:08:43 -08:00
Angelos Katharopoulos
eab93985b8
Update custom function docs ( #1748 )
2025-01-03 16:35:25 -08:00
Awni Hannun
259025100e
Fix nd ternary on GPU ( #1746 )
2025-01-03 11:52:17 -08:00
Awni Hannun
c9d30aa6ac
MLX in C++ example ( #1736 )
...
* MLX in C++ example
* nits
* fix docs
2025-01-02 19:09:04 -08:00