Commit Graph

613 Commits

Author SHA1 Message Date
Alex Barron
5cd97f7ffe
Bitwise Inverse (#1862)
* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
2025-02-13 08:44:14 -08:00
Awni Hannun
d274ae77f2
More buffer donation in some cases (#1858)
* more donation

* fix

* add test
2025-02-12 19:41:37 -08:00
Alex Barron
55c5ac7820
fix int64 bug (#1860) 2025-02-12 19:23:46 -08:00
Angelos Katharopoulos
0145911bea
Fixes output donation for IO ops on the GPU (#1857) 2025-02-12 10:52:30 -08:00
Awni Hannun
0a5215693e
Fix grad copies (#1854)
* fix grad with copies

* add test

* add test
2025-02-11 15:26:42 -08:00
Awni Hannun
2a45056ba8
Cycle leak break (#1856)
* detect and break leaks in custom function

* detect and break leaks in custom function
2025-02-11 14:45:02 -08:00
Abe Leininger
a5ededf1c3
CPU LU factorization and linear solvers (#1451)
* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-02-10 12:32:24 -08:00
Franck Verrot
7df3f792a2
Ensure Conv2D and Conv3D's kernel sizes aren't trimmed (#1852)
Before the change, this snippet:

```
print(nn.Conv1d(1, 32, 3, padding=1))
print(nn.Conv2d(1, 32, (3, 3), padding=1))
print(nn.Conv3d(1, 32, (3, 3, 3), padding=1))
```

would output:

```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3,), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```

After the change, the output will be:

```
Conv1d(1, 32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=1, bias=True)
Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=1, bias=True)
```
2025-02-10 06:27:01 -08:00
Angelos Katharopoulos
9eb7d7362f
Fix Split::vmap (#1845) 2025-02-08 09:22:13 -08:00
Awni Hannun
1c0c118f7c
Fp64 on the CPU (#1843)
* add fp64 data type

* clean build

* update docs

* fix bug
2025-02-07 15:52:22 -08:00
Awni Hannun
83a0340fa7
allow command (#1836) 2025-02-06 10:32:24 -08:00
Awni Hannun
af1b725fda
Fix a couple of slicing bugs (#1827)
* fix a few bugs

* fix conv grad

* speedup test

* comment
2025-02-05 19:50:08 -08:00
Awni Hannun
9174606d4c
fix sort (#1835) 2025-02-05 17:16:27 -08:00
Awni Hannun
ca305afdbe
loading empty list is ok when strict = false (#1834) 2025-02-05 16:19:27 -08:00
Awni Hannun
ec7c7def40
no line buffer for mpi jobs (#1825) 2025-02-03 12:02:15 -08:00
Angelos Katharopoulos
f5cc1eea72
Allow different value dimensions in sdpa_vector (#1811) 2025-01-31 20:58:59 -08:00
Awni Hannun
b7c9f1d38f
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives

* add transforms

* comment
2025-01-31 20:48:08 -08:00
Angelos Katharopoulos
ded914f442
Small distributed launch helper (#1810) 2025-01-29 17:55:04 -08:00
Awni Hannun
4758c8baa1
Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)
* start to cleanup/unify accelerate and common back-ends

* more progress

* simplify

* add half type and allow infs in simd exp

* unify softmax + quantized, more dispatches to simd quantized mm

* add sin/cos, use simd in vector-scalar ops

* faster CPU vectorize quant

* faster erf/erfinv
2025-01-29 14:34:49 -08:00
Awni Hannun
1017ac4a9e
add dilation for conv 3d layers + test for 3d conv w/ dilation (#1802) 2025-01-28 06:17:07 -08:00
Angelos Katharopoulos
ccb61d7aae
Ring distributed backend (#1784) 2025-01-27 22:15:01 -08:00
Awni Hannun
2235dee906
catch stream errors earlier to avoid aborts (#1801) 2025-01-27 14:05:43 -08:00
Awni Hannun
28091aa1ff
allow build python lib without specifying path (#1799) 2025-01-27 11:22:35 -08:00
Awni Hannun
121d9a0702
Fix rope fallback to not upcast (#1797)
* fix rope fallback to not upcast

* Update mlx/fast.cpp

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-01-26 19:07:21 -08:00
Nick
0cea88bcc5
Use @ matrix multiplication syntax to document matrix-matrix multiplication (#1793)
Co-authored-by: Nick Thompson <nicholas_a_thompson@apple.com>
2025-01-25 16:02:36 -08:00
Angelos Katharopoulos
72146fc4cd
Einsum ellipsis (#1788) 2025-01-25 01:28:03 -08:00
Awni Hannun
e6a7ab9675
non square qr (#1783) 2025-01-21 14:07:47 -08:00
Awni Hannun
90532b1f37
recompile when shapeless is different (#1776) 2025-01-20 21:07:10 -08:00
Awni Hannun
0c259961ac
matmul jvps (#1772) 2025-01-17 10:36:26 -08:00
Awni Hannun
33421c1dd3
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth

* add grad of module test
2025-01-14 14:33:18 -08:00
Nripesh Niketan
5cc5201914
feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-01-13 07:29:20 -08:00
wrmsr
a4a2764a52
Fix broadcast_arrays python sig (#1763) 2025-01-10 12:33:26 -08:00
Awni Hannun
657f466402
use sdpa and exportable functions in transformer multi head attention (#1760) 2025-01-09 13:11:55 -08:00
Alex Barron
c7b0300af5
Fix batched qmv bug (#1758) 2025-01-09 11:45:57 -08:00
Awni Hannun
1ccaf80575
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
2025-01-09 11:04:24 -08:00
Awni Hannun
d1766f2c70
Add boolean mask support in vector SDPA (#1757) 2025-01-07 20:24:53 -08:00
Awni Hannun
516ded618b
Dynamic slicing (#1741)
* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
2025-01-07 14:02:16 -08:00
Awni Hannun
d5ec172c95
Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa

* more permissive donation in ternary
2025-01-06 16:57:07 -08:00
Angelos Katharopoulos
25b3a3e541
Optionally specify names for arrays when exporting (#1749) 2025-01-06 13:07:46 -08:00
Awni Hannun
058d6ce683
mpi send use input as output (#1750)
* mpi send use input as output

* move earlier
2025-01-06 06:08:43 -08:00
Angelos Katharopoulos
eab93985b8
Update custom function docs (#1748) 2025-01-03 16:35:25 -08:00
Awni Hannun
259025100e
Fix nd ternary on GPU (#1746) 2025-01-03 11:52:17 -08:00
Awni Hannun
c9d30aa6ac
MLX in C++ example (#1736)
* MLX in C++ example

* nits

* fix docs
2025-01-02 19:09:04 -08:00
Angelos Katharopoulos
8544b42007
Add namespace (#1745) 2025-01-02 16:49:23 -08:00
Awni Hannun
6fa0501387
Fix concatenate/slice_update vjp + reduce binary size (#1735)
* fix concatenate vjp + reduce binary size

* also cast in slice update
2025-01-02 16:36:33 -08:00
Awni Hannun
ae69cb15e9
shapeless compile in docs and partially shapeless reshape (#1742) 2025-01-02 16:24:42 -08:00
Venkata Naga Aditya Datta Chivukula
491fa95b1f
Added Kronecker Product (#1728) 2025-01-02 16:00:34 -08:00
Awni Hannun
4ba0c24a8f
Export / import functions to / from a file (#1642)
* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
2024-12-24 11:19:13 -08:00
Awni Hannun
ebfe64b92d
shapeless slice update and broadcast when possible (#1727) 2024-12-23 11:25:15 -08:00
Awni Hannun
0308e9af71
Allow offset to be an mx.array for mx.fast.rope (#1724)
* allow offset for rope

* comment
2024-12-19 15:51:44 -08:00
Awni Hannun
c3628eea49
Add mx.finfo and use it when making causal mask (#1726)
* finfo

* fixes

* docs
2024-12-19 14:52:41 -08:00
Awni Hannun
e03f0372b1
More shape type (#1705)
* more shape type

* fix
2024-12-19 08:08:20 -08:00
Awni Hannun
7480059306
track resource limit and throw if exceeded (#1718) 2024-12-18 18:45:58 -08:00
Awni Hannun
f110357aaa
Bump nanobind to 2.4 + fix (#1710)
* bump nanobind to 2.4 + fix

* fix
2024-12-17 10:57:54 -08:00
Tomohiro Oga
a6b426422e
add cubic to type hinting for upsample (#1709) 2024-12-17 07:30:23 -08:00
Awni Hannun
d03c01dfbc
fix unflatten vjp (#1708) 2024-12-16 18:37:57 -08:00
Cheng
af5a614aad
Eval before cleanup so model file is unlocked (#1702) 2024-12-14 21:41:49 -08:00
Cheng
dfccd17ab9
Use psutil to get memory info on Windows (#1700) 2024-12-13 19:50:13 -08:00
Awni Hannun
50f3535693
Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
2024-12-12 17:00:44 -08:00
Awni Hannun
9111999af3
Fix small sort with metal validation (#1695) 2024-12-12 09:21:45 -08:00
Awni Hannun
6bd28d246e
Allow no copy negative strides in as_strided and slice (#1688)
* allow no copy negative strides in as_strided and slice

* fix jit

* fix jit
2024-12-12 08:59:45 -08:00
Awni Hannun
3a21f61772
Fix build (#1693) 2024-12-11 23:56:25 -08:00
Awni Hannun
4e1e9520e1
Flatten and unflatten (#1692)
* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
2024-12-11 21:51:37 -08:00
Cheng
0bf19037ca
Remove "using namespace mlx::core" in python/src (#1689) 2024-12-11 15:45:39 -08:00
Awni Hannun
f76a49e555
ExpandDims primitive (#1687)
* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
2024-12-10 16:39:07 -08:00
Cheng
92ab6bdeb8
Fix shared library not exporting symbols on Windows (#1684)
* Fix shared library not exporting symbols on Windows

* Function name style
2024-12-10 13:59:14 -08:00
Cheng
a59fae040f
Fix library output directory for MSVC (#1681) 2024-12-09 19:07:50 -08:00
Awni Hannun
29a620cab2
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
2024-12-09 18:57:38 -08:00
Cheng
87d7a2520e
Use Py_ssize_t in python bindings (#1678)
* Use Py_ssize_t in python bindings

* Args passed to std::max must be same type
2024-12-09 12:59:19 -08:00
Awni Hannun
40c62c1321
Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00
Awni Hannun
35b412c099
Fix compile hasher for string constants. (#1677)
* fix hash

* add test

* nit
2024-12-09 09:26:18 -08:00
mt_caret
fd3377dd1f
Support bias correction in Adam and AdamW optimizers (#1640) 2024-12-06 12:13:34 -08:00
Awni Hannun
bc2a29f033
fix (#1654) 2024-12-06 10:48:58 -08:00
Awni Hannun
e047fd977d
compile changes if stream changes (#1644) 2024-12-03 14:37:44 -08:00
Alex Barron
1445dcaa60
let class predicate specify quantization parameters (#1638) 2024-12-02 14:09:28 -08:00
Awni Hannun
aa86876813
fix transformer decoder post norm LN (#1637) 2024-12-02 07:02:17 -08:00
Awni Hannun
7cbb4aef17
Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Alex Barron
c79f6a4a8c
3 and 6 bit quantization (#1613)
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
Awni Hannun
0c5eea226b
Reduce specializations (#1607)
* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
2024-11-21 19:53:00 -08:00
Angelos Katharopoulos
d8c824c594
Formatting fixes (#1606) 2024-11-20 15:30:36 -08:00
Saanidhya
cb431dfc9f
Adds 3D pooling (#1526) 2024-11-19 16:45:24 -08:00
Awni Hannun
61d787726a
Fix view scalar bug segfault (#1603)
* fix view scalar bug

* fix view scalar bug

* one more fix
2024-11-19 10:54:05 -08:00
Angelos Katharopoulos
5e89aace9b
Fix concatenate vmap (#1600) 2024-11-19 10:44:04 -08:00
Awni Hannun
bf481e8e5d
Fix sibling leak (#1590)
* add test

* fix + test

* fix fix
2024-11-18 19:17:01 -08:00
Awni Hannun
9bd03dd9b4
More buffer donation with no-ops (#1591)
* more donation

* fix test

* fix build
2024-11-18 08:35:41 -08:00
Awni Hannun
8c34c9dac4
throw for invalid case and remove test (#1575) 2024-11-08 12:04:03 -08:00
Awni Hannun
91c0277356
fix per-example mask + docs in sdpa (#1574) 2024-11-08 11:51:15 -08:00
Awni Hannun
59247c2b62
add groups in conv2d (#1569) 2024-11-07 13:57:53 -08:00
Awni Hannun
54f05e7195
Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
Alex Barron
26be608470
Add split_k qvm for long context (#1564)
* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
2024-11-05 11:25:19 -08:00
Awni Hannun
76f275b4df
error in rms for wrong size (#1562) 2024-11-04 13:24:02 -08:00
Angelos Katharopoulos
62f297b51d
Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Awni Hannun
09bc32f62f
No extra reshape (#1557)
* no extra reshape

* lint
2024-11-02 19:07:20 -07:00
Alex Barron
9e516b71ea
Add dispatchThreads to custom kernel doc (#1551)
* add dispatchThreads info

* update

* add link
2024-11-01 13:07:48 -07:00
Awni Hannun
57c6aa7188
fix multi output leak (#1548) 2024-10-31 09:32:01 -07:00
Awni Hannun
4f72c66911
improvements to scatter / gather (#1541) 2024-10-30 19:30:54 -07:00
Alex Barron
048fabdabd
Fix vmap constant output size (#1524)
* use inputs to determine output size

* remove noop vmap tests
2024-10-30 16:16:53 -07:00
Awni Hannun
d2ff04a4f2
fix format (#1539) 2024-10-28 18:29:14 -07:00
Awni Hannun
0eb56d5be0
Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a
[Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00