Commit Graph

648 Commits

Author SHA1 Message Date
Awni Hannun
7bb063bcb3
Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias

* higher tol
2025-04-29 13:03:09 -07:00
hdeng-apple
167b759a38
Fix typos (#2136) 2025-04-29 07:26:05 -07:00
charan-003
99b9868859
Clarify dimension notation in conv1d, conv2d, and conv3d docstrings (#2123)
* Clarify dimension notation in conv1d, conv2d, and conv3d docstrings

* Updating transposed convs in conv1d, conv2d, and conv3d

---------

Co-authored-by: Sai Charan Arvapally <saicharan@Sais-MacBook-Pro.local>
2025-04-25 12:18:30 -07:00
Awni Hannun
fbc89e3ced
fix pinv (#2110) 2025-04-23 13:08:28 -07:00
Param Thakkar
600e87e03c
Added output_padding parameters in conv_transpose (#2092) 2025-04-23 09:26:33 -07:00
Hyunsung Lee
3836445241
Add broadcast_shapes in python API (#2091) 2025-04-22 18:57:39 -07:00
Yury Popov
1d2c9d6a07
Complex scan (#2094) 2025-04-22 18:56:28 -07:00
Awni Hannun
e8ac6bd2f5
irfft throws instead of segfaults on scalars (#2109) 2025-04-22 10:25:55 -07:00
Awni Hannun
fdadc4f22c
Add more complex unary ops (#2101) 2025-04-21 13:04:54 -07:00
Awni Hannun
79b527f45f
conv vmap (#2102) 2025-04-21 13:04:39 -07:00
Awni Hannun
dc4eada7f0
Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import

* comment
2025-04-21 07:17:22 -07:00
Awni Hannun
55935ccae7
fix py gc edge case (#2079) 2025-04-18 12:46:53 -07:00
Angelos Katharopoulos
5de6d94a90
Gather qmm batched kernel and refactoring of quantized (#2078) 2025-04-17 13:53:11 -07:00
Angelos Katharopoulos
99eefd2ec0
Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b
LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Angelos Katharopoulos
c4189a38e4
Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b
nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
ef7ece9851
fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb
Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Anastasiia Filippova
515f104926
Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Awni Hannun
00794c42bc
Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Awni Hannun
f2c85308c1
add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
ec5e2aae61
nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Jagrit Digani
3290bfa690
Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::varaint from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f
Depthwise Conv2D optimization (#2036)
- Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
de5f38fd48
Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a
Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938
Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Awni Hannun
28f39e9038
Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
05d7118561
causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66
enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291
iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Yi Wang
a8931306e1
Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Chunyang Wen
022eabb734
Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
a84cc0123f
promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Angelos Katharopoulos
4eef8102c9
Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b
Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Awni Hannun
2a980a76ce
Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Awni Hannun
4e1994e9d7
move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b
update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd
fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
Awni Hannun
005e7efa64
fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84
Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c164fca8c
Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock

* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b
Update smooth_l1_loss in losses.py (#1974)
According the definition of smooth_l1_loss, the line 

diff = predictions - targets

Should be updated to 

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch smooth_l1_loss
2025-03-19 20:19:02 -07:00
Chunyang Wen
3779150750
refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
Chunyang Wen
45ad06aac8
Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name

* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329
Use same accumulation precision in gemv as gemm (#1962)
* use same accumulation precision in gemv as gemm

* faster

* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240
fix grad with inplace updates (#1961) 2025-03-13 19:13:09 -07:00