Angelos Katharopoulos
5e6c130d93
RMS norm without scaling ( #1915 )
2025-02-28 20:26:57 -08:00
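For context: without a learned scale, RMS norm reduces to dividing by the root-mean-square of the last axis. A minimal NumPy sketch of the semantics (names here are illustrative, not MLX's API):

```python
import numpy as np

def rms_norm_no_scale(x, eps=1e-5):
    # Normalize by the root-mean-square over the last axis;
    # no learned weight vector is applied afterwards.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms
```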
Angelos Katharopoulos
f5cc1eea72
Allow different value dimensions in sdpa_vector ( #1811 )
2025-01-31 20:58:59 -08:00
Awni Hannun
a4667da1eb
Faster synchronization Fence primitive ( #1773 )
* try faster synchronization
move event
fixes
update bench
fix
fix
* non-functioning kernel
* try alternative fence
* cleanup barrier
* get rid of event_fence
* update benchmarks
* doc string in metal fence
2025-01-17 18:42:19 -08:00
Awni Hannun
d1766f2c70
Add boolean mask support in vector SDPA ( #1757 )
2025-01-07 20:24:53 -08:00
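A boolean mask in scaled dot-product attention drops masked positions from the softmax entirely. A NumPy sketch of the semantics (illustrative, not the Metal kernel; assumes every query row attends to at least one key):

```python
import numpy as np

def sdpa_bool_mask(q, k, v, mask):
    # q: (L, D), k/v: (S, D), mask: (L, S) with True = attend
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # masked keys get zero weight
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v
```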
Cheng
4f9b60dd53
Remove "using namespace mlx::core" in benchmarks/examples ( #1685 )
* Remove "using namespace mlx::core" in benchmarks/examples
* Fix building example extension
* A missing one in comment
* Fix building on M chips
2024-12-11 07:08:29 -08:00
Jagrit Digani
02bec0bb6d
Matrix Attention kernel ( #1610 )
* Rough INIT
* [WIP]: Loading and Matmuls added
* [WIP]: Reductions and a minimum working aligned kernel at headdim = 64
* [WIP] Added headdim 80 for testing
* [WIP] Update dispatch params for testing
* [WIP] Add support for unaligned seq lengths - still looks messy
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Enable gqa support
* Update benchmark and switch off 128 headdim
* Update headdim 128 tuning
* Remove older fast attention code. Write out O strided
* Disable hd=128 until further optimizations
* Enable bf16
* Fix data size bug
* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
Angelos Katharopoulos
073076ac7d
2-Pass Sdpa Inference Kernel ( #1597 )
2024-11-18 17:31:53 -08:00
Angelos Katharopoulos
248431eb3c
Reductions update ( #1351 )
2024-11-04 22:25:16 -08:00
Awni Hannun
4f72c66911
improvements to scatter / gather ( #1541 )
2024-10-30 19:30:54 -07:00
Angelos Katharopoulos
50d8bed468
Fused attention for single query ( #1497 )
2024-10-18 00:58:52 -07:00
Max-Heinrich Laves
adcc88e208
Conv cpu improvements ( #1410 )
2024-09-15 18:45:10 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution ( #1245 )
* initial implementation for conv_transpose
ran pre-commit
implemented conv_transpose
updated conv_general docstring
updated conv_general docstring
updated code comments
removed commented run_conv_checks
updated acknowledgments
added missing entry to ops.rst
added op to nn.layers
resolved merge conflicts
* removed ConvolutionTranspose primitive as suggested by reviewer
removed ConvolutionTranspose primitive as suggested by reviewer
* remove transpose flag, add another test
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
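Transposed convolution is the adjoint of convolution: each input element scatters a scaled copy of the kernel into the output. A naive 1-D NumPy sketch of the semantics (not the optimized implementation):

```python
import numpy as np

def conv_transpose_1d(x, w, stride=1):
    # x: (N,), w: (K,); output length is (N - 1) * stride + K
    n, k = len(x), len(w)
    out = np.zeros((n - 1) * stride + k)
    for i in range(n):
        out[i * stride : i * stride + k] += x[i] * w
    return out
```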
Awni Hannun
5f7d19d1f5
MPI ops in GPU stream for faster comms ( #1356 )
2024-08-26 15:12:50 -07:00
Awni Hannun
baf9fa5f42
Einsum ( #1269 )
* einsum initial
* fix comma break
* sum axis was wrong
* small cleanups
* python binding
* changed bindings to resemble numpy
* remove todo comment
* comment changes
* add count of operands/inputs
* fail fast if operands list is empty
* ignore comma if no output
* einsum path matching numpy
* getting somewhere with path
* remove print
* it passes the first test
* moved einsum tests to separate file
* separated einsum path
* moved einsum naive
* remove space from equation
* fast fail if no operands passed
* update tests and remove printf
* small cleanup
* some more cleanups
* removed python helper file
* ack
* utilize std for finding min in vector
* duplicate def
* remove the tuple as it was unreadable
* moved einsum_naive back to ops
* remaining isn't needed
* avoid creating another set
* cleanup
* greedy path, start of naive einsum
* more einsum
* fix some bugs
* some more fixes, tests pass
* benchmark
* some simplify
* fix einsum and test
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
* add a bunch more tests and fix a bunch more bugs
* some docs nits
---------
Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
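The bindings were changed to resemble NumPy's, including path optimization, so the NumPy equivalent illustrates the behavior: a greedy path picks a pairwise contraction order that minimizes cost.

```python
import numpy as np

a = np.random.rand(8, 32)
b = np.random.rand(32, 64)
c = np.random.rand(64, 4)

# (a @ b) @ c and a @ (b @ c) agree numerically but differ in FLOPs;
# a greedy path search picks the cheaper pairwise contraction order.
path, _ = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")
out = np.einsum("ij,jk,kl->il", a, b, c, optimize=path)
```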
Alex Barron
a3c287354f
Fast Hadamard Transform ( #1249 )
* Working hadamard for powers of 2
* working for m*2^k
* add scale and check contiguity
* add size check
* clean up
* fix test
* add grads + vmap
* gpu only
* skip on linux
* test typo
* add cpu impl
* remove gpu only tests
* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
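The power-of-two case is the in-place butterfly below; the m·2^k case factors into this plus a small dense Hadamard matmul. A NumPy sketch with the scale applied at the end:

```python
import numpy as np

def fwht(x, scale=None):
    # Fast Walsh-Hadamard transform; len(x) must be a power of 2.
    x = np.asarray(x, dtype=float).copy()
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            for j in range(i, i + h):
                a, b = x[j], x[j + h]
                x[j], x[j + h] = a + b, a - b
        h *= 2
    return x * (1.0 / np.sqrt(len(x)) if scale is None else scale)
```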
Alex Barron
27d70c7d9d
Feature complete Metal FFT ( #1102 )
* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
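Bluestein's algorithm (whose constant computation is simplified above) re-expresses an arbitrary-length DFT as a power-of-two circular convolution of chirp sequences; a NumPy sketch:

```python
import numpy as np

def bluestein_dft(x):
    # n-point DFT via a circular convolution of size m = 2^p >= 2n - 1
    n = len(x)
    m = 1 << (2 * n - 1).bit_length()
    chirp = np.exp(-1j * np.pi * np.arange(n) ** 2 / n)
    a = np.zeros(m, dtype=complex)
    a[:n] = x * chirp
    b = np.zeros(m, dtype=complex)
    b[:n] = np.conj(chirp)
    b[m - n + 1:] = np.conj(chirp[1:])[::-1]  # wrap negative indices
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return conv[:n] * chirp
```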
Nikhil Mehta
0b7d71fd2f
Add softmin, hardshrink, hardtanh ( #1180 )
---------
Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
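For reference, the definitions behind these three, sketched in NumPy (matching the usual PyTorch-style formulas):

```python
import numpy as np

def softmin(x, axis=-1):
    # softmin(x) = softmax(-x), computed stably
    e = np.exp(x.min(axis=axis, keepdims=True) - x)
    return e / e.sum(axis=axis, keepdims=True)

def hardtanh(x, min_val=-1.0, max_val=1.0):
    return np.clip(x, min_val, max_val)

def hardshrink(x, lambd=0.5):
    return np.where(np.abs(x) > lambd, x, 0.0)
```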
nicolov
81def6ac76
Fix benchmark ( #1175 )
2024-06-04 07:50:46 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences ( #964 )
* Metal shaders for efficient self attention on large sequences
Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
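The flash-attention-1 scale correction keeps a running max and normalizer so the softmax can proceed block-by-block without materializing the full score matrix. A NumPy sketch for a single query:

```python
import numpy as np

def streaming_attention(q, k_blocks, v_blocks):
    # q: (D,); k_blocks/v_blocks: lists of (B, D) arrays
    m, l = -np.inf, 0.0                    # running max and softmax normalizer
    o = np.zeros_like(v_blocks[0][0])
    for kb, vb in zip(k_blocks, v_blocks):
        s = kb @ q / np.sqrt(len(q))       # scores for this key block
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)           # rescale what was accumulated so far
        p = np.exp(s - m_new)
        o = o * corr + p @ vb
        l = l * corr + p.sum()
        m = m_new
    return o / l
```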
Rifur13
9401507336
Add groups to 2-D convolutions ( #1129 )
* Added groups to 2-D convolutions. Only implemented for **some** specializations.
Also fixed 1D grouped convs with different kernel strides and added more tests.
* fix channels condition
2024-05-22 20:01:44 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d ( #948 )
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* cleanup
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
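With groups, input and output channels are partitioned so each slice of output channels convolves only against its own slice of input channels. A naive NumPy sketch of the semantics (stride 1, valid padding):

```python
import numpy as np

def grouped_conv1d(x, w, groups):
    # x: (L, C_in), w: (C_out, K, C_in // groups)
    l, c_in = x.shape
    c_out, k, _ = w.shape
    gin, gout = c_in // groups, c_out // groups
    out = np.zeros((l - k + 1, c_out))
    for g in range(groups):
        xg = x[:, g * gin:(g + 1) * gin]           # this group's input channels
        for o in range(g * gout, (g + 1) * gout):  # this group's output channels
            for t in range(l - k + 1):
                out[t, o] = np.sum(xg[t:t + k] * w[o])
    return out
```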
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 ( #915 )
* add Metal FFT for powers of 2
* skip GPU test on linux
* fix contiguity bug
* address comments
* Update mlx/backend/metal/fft.cpp
* Update mlx/backend/metal/fft.cpp
* fix bug in synch
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks ( #984 )
2024-04-11 07:27:53 -07:00
Cheng
913b19329c
Add missing && when forwarding args ( #925 )
Without the &&, args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace ( #883 )
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Angelos Katharopoulos
6ee1112f30
Fix copy donation and add partial rope ( #881 )
2024-03-22 17:28:26 -07:00
Jagrit Digani
6686e61ca4
Reduce update ( #783 )
* Split reduction files to reduce compile times
* Add small and medium axis size specializations for row reductions
* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Jagrit Digani
776c3d226d
Convolution update ( #651 )
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation, and update winograd conv routing
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Rifur13
126c9869c8
Implement the 'where' primitive for conditional selection ( #664 )
2024-02-22 15:10:48 -08:00
Vijay Krish
972d9a3aea
Up to 10x faster scatter. ( #709 )
* Faster scatter.
Add specialization for 1-d index tensors.
* Address review comments.
- Check for row contiguity of index and update tensors
instead of checking strides.
- Add support for 1d specialization with col contiguous update
tensor, along with a test.
* Nit1
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Nit2
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-21 11:09:30 -08:00
Awni Hannun
5798256fcf
Shapeless compilation for some graphs ( #687 )
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
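A usage sketch with the Python API (assuming the `shapeless` flag on `mx.compile`, which is what this change enables for eligible graphs):

```python
import mlx.core as mx

def gelu_approx(x):
    return x * mx.sigmoid(1.702 * x)

# With shapeless=True the traced graph is reused across input shapes;
# only changes in rank or dtype trigger a recompile.
fast_gelu = mx.compile(gelu_approx, shapeless=True)

fast_gelu(mx.random.normal((4, 8)))
fast_gelu(mx.random.normal((16, 32)))  # no retrace
```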
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op ( #676 )
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
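RoPE rotates feature pairs by position-dependent angles θ_i = pos · base^(−2i/d). A NumPy sketch of one common (non-interleaved) pairing that splits the feature dimension in half; the "traditional" variant mentioned above pairs adjacent features instead:

```python
import numpy as np

def rope(x, base=10000.0):
    # x: (seq, dim); rotates pairs (x[:, i], x[:, i + dim // 2])
    seq, dim = x.shape
    half = dim // 2
    freqs = base ** (-np.arange(half) / half)   # base^(-2i/dim)
    ang = np.arange(seq)[:, None] * freqs[None, :]
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate(
        [x1 * np.cos(ang) - x2 * np.sin(ang),
         x1 * np.sin(ang) + x2 * np.cos(ang)], axis=-1)
```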
Vijay Krish
2fdc2462c3
Faster gather and scatter. ( #682 )
Reduce unnecessary integer ops, especially since
these kernels are integer bound.
Increase number of iterations for benchmarks for
better smoothing.
Github Issue #506
Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Nripesh Niketan
0dbc4c7547
feat: Update pre-commit-config.yaml ( #667 )
2024-02-11 06:08:20 -08:00
Vijay Krish
06072601ce
Scatter optimization : Eliminate 64b integer divide. ( #662 )
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.
Github Issue #506
Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Awni Hannun
e319383ef9
Faster gather ( #626 )
* faster gather
* update copyright
2024-02-04 17:25:44 -08:00
Awni Hannun
3c2f192345
Propagate nans in binary ops ( #579 )
* propagate nans in binary ops
* handle empty matmul
* cpu minimum/maximum propagate nan
* benchmark maximum
* add min as well
* throw on negative indices with full
* verbose on linux
* fix matmul for zero K
2024-01-29 11:19:38 -08:00
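The semantics match NumPy's `maximum`/`minimum`, which propagate NaNs (unlike `fmax`/`fmin`, which ignore them):

```python
import numpy as np

print(np.maximum(np.nan, 1.0))  # nan  -- NaN propagates
print(np.fmax(np.nan, 1.0))     # 1.0  -- NaN is ignored
```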
Awni Hannun
86e0c79467
remove stale benchmarks ( #527 )
2024-01-22 22:17:58 -08:00
Awni Hannun
7a34e46677
Quantize with groups of 32 ( #511 )
* allow quantize with group sizes of 32
* missing cpu dispatch
* remove print
* Fix qvm for group_size 32
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
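The quantization here is affine per group of weights: each group stores a scale and a bias alongside the packed values. A NumPy sketch of the round trip (illustrative; the real kernels also pack values into 32-bit words):

```python
import numpy as np

def quantize(w, group_size=32, bits=4):
    # w: array whose last axis is divisible by group_size
    g = w.reshape(-1, group_size)
    bias = g.min(axis=1, keepdims=True)
    scale = (g.max(axis=1, keepdims=True) - bias) / (2 ** bits - 1)
    scale = np.maximum(scale, 1e-8)  # guard constant groups
    q = np.round((g - bias) / scale).astype(np.uint8)
    return q, scale, bias

def dequantize(q, scale, bias, shape):
    return (q * scale + bias).reshape(shape)
```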
Jagrit Digani
78102a47ad
Update GEMM ( #424 )
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
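Split-K partitions the reduction dimension so several threadgroups compute partial products that are summed at the end. The numerical effect, in NumPy:

```python
import numpy as np

def splitk_matmul(a, b, splits=4):
    # Partition K, compute partial GEMMs, then reduce.
    parts = np.array_split(np.arange(a.shape[1]), splits)
    return sum(a[:, p] @ b[p, :] for p in parts)
```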
Angelos Katharopoulos
c15fe3e61b
Allow arbitrary first dimension in quantization kernels. ( #458 )
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Awni Hannun
f099ebe535
Multi output primitives ( #330 )
* Multi-output primitives
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T ( #349 )
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Josh Soref
44c1ce5e6a
Spelling ( #342 )
* Fix misspellings across the codebase: accumulates, across, additional, against, among, array, at least, available, axes, basically, bfloat, bounds, broadcast, buffer, class, coefficients, collision, combinations, committing, computation, consider, constructing, conversions, correctly, corresponding, declaration, default, dependency, destination, destructor, dimensions, divided, element-wise, elements, endianness, equivalent, explicitly, github, indices, irregularly, memory, metallib, negative, notable, optional, otherwise, overridden, partially, partition, perform, perturbations, positively, primitive, repeat, repeats, respect, respectively, result, rounding, separate, skipping, structure, the, transpose, unnecessary, unneeded, unsupported
---------
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Angelos Katharopoulos
9e6b8c9f48
Refactor the reduction kernels ( #277 )
2023-12-24 14:47:57 -08:00
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation ( #205 )
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Diogo
02de234ef0
Activations LeakyReLU / PReLU / Softplus / Mish ( #109 )
* Leaky_relu / prelu / softplus / mish
* added tests
* updated bench
* remove torch refs, add init to PReLU
* added arXiv reference to mish
* added missing docs
2023-12-11 19:40:57 -08:00
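For reference, the formulas behind these activations, in numerically stable NumPy form:

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), stable for large |x|
    return np.maximum(x, 0) + np.log1p(np.exp(-np.abs(x)))

def mish(x):
    return x * np.tanh(softplus(x))

def leaky_relu(x, negative_slope=0.01):
    return np.where(x > 0, x, negative_slope * x)

def prelu(x, alpha):
    # leaky_relu with a learned (per-channel) slope
    return np.where(x > 0, x, alpha * x)
```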
Nicholas Santavas
f5df47ec6e
Add Step, ELU, SELU, Swish activation functions ( #117 )
* Add Step, ELU, SELU, Swish activation functions
This commit adds the Step, ELU, SELU and Swish activation functions
* add to the docs
* review
2023-12-11 17:04:07 -08:00
Jason
b0cd092b7f
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid ( #108 )
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Zach Schillaci
5b9be57ac3
Add isort pre-commit and run ( #68 )
2023-12-08 11:31:47 -08:00