Angelos Katharopoulos
99eefd2ec0
Gather mm new kernel and small refactoring ( #2040 )
2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b
LogCumSumExp ( #2069 )
2025-04-13 01:27:29 -07:00
Angelos Katharopoulos
c4189a38e4
Add float mask to sdpa vector ( #2068 )
2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b
nit: fix exception handling ( #2066 )
2025-04-11 14:12:08 -07:00
Awni Hannun
ef7ece9851
fix fft bug ( #2062 )
2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb
Fix the test and add custom min/max reductions for uncommon MPI types ( #2060 )
2025-04-10 17:01:17 -07:00
Anastasiia Filippova
515f104926
Min / max reductions ( #2041 )
2025-04-09 23:22:20 -07:00
Awni Hannun
00794c42bc
Fix causal mask sdpa vec ( #2053 )
...
* fix sdpa vector causal mask
* test
2025-04-08 09:11:23 -07:00
Awni Hannun
f2c85308c1
add a half simd gemm fallback ( #2046 )
...
* add a half simd gemm fallback
* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
ec5e2aae61
nit in doc ( #2044 )
2025-04-04 12:04:17 -07:00
Jagrit Digani
3290bfa690
Add new sdpa function overload ( #2035 )
...
* Add new sdpa function overload
* Address comments
* Remove std::varaint from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f
Depthwise Conv2D optimization ( #2036 )
...
- Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
de5f38fd48
Custom logsumexp ( #2028 )
...
* initial custom logsumexp
* more tests
* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a
Swap -inf for finite_minimum value ( #2029 )
2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938
Add missing funcs to docs ( #2021 )
2025-03-30 18:29:33 -07:00
Awni Hannun
28f39e9038
Log for complex numbers in Metal ( #2025 )
...
* Log for complex numbers in Metal
* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
05d7118561
causal vector sdpa ( #2018 )
...
* causal vector sdpa
* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66
enable complex gemm ( #2017 )
2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291
iinfo and scalar overflow detection ( #2009 )
2025-03-27 19:54:56 -07:00
Yi Wang
a8931306e1
Remove unused variable in CMakeBuild ( #2011 )
...
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Chunyang Wen
022eabb734
Remove unused import ( #1987 )
2025-03-24 20:19:32 -07:00
Awni Hannun
a84cc0123f
promote mask when needed ( #1998 )
2025-03-23 19:58:28 -07:00
Angelos Katharopoulos
4eef8102c9
Distributed layers ( #1270 )
2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b
Add a ring all gather ( #1985 )
2025-03-21 13:36:51 -07:00
Awni Hannun
2a980a76ce
Add stats and limit to common allocator and enable tests ( #1988 )
...
* add stats to common allocator and enable tests
* linux memory and default
* fix
2025-03-21 12:28:36 -07:00
Awni Hannun
4e1994e9d7
move memory APIs into top level mlx.core ( #1982 )
2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b
update the formula of smooth_l1_loss ( #1986 )
2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd
fix malloc or wait deadlock ( #1976 )
2025-03-20 16:48:43 -07:00
Awni Hannun
005e7efa64
fix mask in sdpa ( #1980 )
...
* fix mask in sdpa
* fix attention mask
* Re-enable routing for array mask
---------
Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84
Update attention tests to show diff, disable array masks ( #1978 )
2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650
Support fused masking in Attention ( #1924 )
...
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c164fca8c
Fix multistream GPU deadlock ( #1969 )
...
* fix multistream GPU deadlock
* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b
Update smooth_l1_loss in losses.py ( #1974 )
...
According the definition of smooth_l1_loss, the line
diff = predictions - targets
Should be updated to
diff = mx.abs(predictions - targets)
After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
Chunyang Wen
3779150750
refactor: all use schedule ( #1973 )
2025-03-19 11:24:04 -07:00
Chunyang Wen
45ad06aac8
Fix typo; Fix lint warning when reuse the same name ( #1968 )
...
* Fix typo; Fix lint warning when reuse the same name
* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329
Use same accumulation precision in gemv as gemm ( #1962 )
...
* use same accumulation precision in gemv as gemm
* faster
* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240
fix grad with inplace updates ( #1961 )
2025-03-13 19:13:09 -07:00
Awni Hannun
32da94507a
fix vmap for flatten ( #1955 )
2025-03-11 10:42:22 -07:00
Awni Hannun
3c3e558c60
Support transposed head/seq for kv ( #1950 )
...
* support transposed head/seq for kv
* fix flaky test
* nit
2025-03-10 10:53:45 -07:00
Chunyang Wen
cffceda6ee
Add type hint for _extra_repr ( #1948 )
2025-03-10 06:05:36 -07:00
Chunyang Wen
d14c9fe7ea
Add file info when raising errors in save ( #1943 )
2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822
Fix obsured warning ( #1944 )
2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330
Fix unreachable warning ( #1939 )
...
* Fix unreachable warning
* Update error message
2025-03-07 17:23:04 -08:00
Chunyang Wen
a198b2787e
Remove unused modules ( #1936 )
2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59
Add doc string for path ( #1937 )
2025-03-06 14:20:09 -08:00
David Wisdom
392b3060b0
Fix typo in randint docstring ( #1932 )
...
This commit fixes a typo in the docstring for mlx.core.random.randint() by changing "roadcastable" to "broadcastable".
2025-03-05 21:48:00 -08:00
Angelos Katharopoulos
0792ff02ff
Only fail when 10 consecutive socket errors occur ( #1928 )
2025-03-05 13:16:19 -08:00
Abe Leininger
3835a428c5
Adds nuclear norm support ( #1894 )
...
* adjust norm unit test tolerance
2025-03-04 13:26:02 -08:00
Angelos Katharopoulos
9680f72cca
Add a multi optimizer ( #1916 )
2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3
Allow debugging in distributed mode ( #1920 )
2025-03-04 13:01:10 -08:00