Commit Graph

283 Commits

Author SHA1 Message Date
Awni Hannun
baf9fa5f42
Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
Jagrit Digani
7f914365fd
Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50
Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
fgranqvist
50eff6a10a
Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
Alex Barron
c34a5ae7f7
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae
some fixes (#1281) 2024-07-23 11:49:05 -07:00
Tim Gymnich
6307d166eb
Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df
Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Angelos Katharopoulos
5c1fa64fb0
Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Alex Barron
bdb36c9a63
add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Alex Barron
2615660e62
Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1
fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
2d6cd47713
Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Awni Hannun
df964132fb
fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
Angelos Katharopoulos
0163a8e57a
Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
496315fe1d
Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893
Fix the hard-shrink test (#1185) 2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f
Add softmin, hardshrink, hardtanh (#1180)
---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Awni Hannun
83b11bc58d
Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Awni Hannun
ea9090bbc4
Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
K Venkat Ramnan
ab977109db
feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b
stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46
Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1
Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1
fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db
Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863
list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2
add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
d568c7ee36
Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
jlwitthuhn
7e5674d8be
Treate 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
Awni Hannun
fb71a82ada
Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Awni Hannun
81dd33af66
allow conversion to dlpack (#1120) 2024-05-16 16:11:37 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
631dfbe673
fix scatter index bug (#1122) 2024-05-14 15:04:58 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
b21242faf1
Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Awni Hannun
9814a2ae12
fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Awni Hannun
21623156a3
Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Nripesh Niketan
79c859e2e0
feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Jagrit Digani
f390957685
Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00