Commit Graph

80 Commits

Author SHA1 Message Date
Rifur13
126c9869c8 Implement the 'where' primitive for conditional selection (#664) 2024-02-22 15:10:48 -08:00
Jagrit Digani
884b4ed43b Fix threadgroup memory in arg reduce (#723) 2024-02-21 19:42:16 -08:00
Vijay Krish
972d9a3aea Up to 10x faster scatter. (#709)
* Faster scatter.

Add specialization for 1-d index tensors.

* Address review comments.

- Check for row contiguity of index, update tensors
  instead of checking strides.
- Add support for 1d specialization with col contiguous update
  tensor, along with a test.

* Nit1

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-21 11:09:30 -08:00
Awni Hannun
5798256fcf Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Awni Hannun
1a4f4c5ea6 Refactor CPU compile preamble (#708)
* refactor cpu preamble

* fix include order

* fix some issues

* fixes for linux

* try to fix includes

* add back warning suppression

* more linux fixes
2024-02-19 06:12:53 -08:00
Jack Mousseau
0925af43b0 Remove unused variables (#706) 2024-02-18 12:50:10 -08:00
Awni Hannun
dc937b8ed3 CPU compile (#691)
* build and load shared object for cpu compile

* nits

* cpu compile tests pass

* fix preamble for g++

* donation

* fix gpu buffer donation

* reuse prebuilt libraries

* faster contiguity conditions

* fix test

* rid compiler warning

* fast erf

* Fix float16 for compile and add more types to cpu compile

* Remove a forgotten comment

* use cached libs

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-17 06:54:32 -08:00
Awni Hannun
c3965fc5ee Separate fast ops and primitives (#699) 2024-02-16 19:16:39 -08:00
Awni Hannun
ccf1645995 Custom primitive + RoPE fast op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32 Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Vijay Krish
2fdc2462c3 Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
these kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

GitHub Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Angelos Katharopoulos
40c108766b Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Awni Hannun
3756381358 Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6 quote file name (#670) 2024-02-11 10:33:30 -08:00
Vijay Krish
06072601ce Scatter optimization: Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

GitHub Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Awni Hannun
7f3f8d8f8d Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185 Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Angelos Katharopoulos
28eac18571 Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Jagrit Digani
316ff490b3 Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Awni Hannun
e319383ef9 Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
David Koski
ebfd3618b0 fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Awni Hannun
cb6156d35d Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Vijay Krish
fcc5ac1c64 Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
Angelos Katharopoulos
0de5988f92 Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jagrit Digani
375446453e Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20 Fix log1p with inf inputs (#592) 2024-01-30 14:02:50 -08:00
Angelos Katharopoulos
65d0b8df9f Fix binary op dispatch (#584) 2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345 Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Awni Hannun
8993382aaa Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
taher
077c1ee64a QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Jagrit Digani
6d3bee3364 Fix oob reads in gemv kernel (#523) 2024-01-22 12:06:04 -08:00
Awni Hannun
7a34e46677 Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Awni Hannun
3d99a8d31d Fix format / build (#489) 2024-01-18 10:01:59 -08:00
Ethan
a749a91c75 Support disabling the Metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* support disabling the Metal buffer cache, since a large amount of unused memory stays cached when an LLM generates long-context tokens

* Run format and add "cache_enabled" feature tests
2024-01-18 08:33:34 -08:00
Angelos Katharopoulos
90c234b7ac Fix round to round half-cases to even (#482) 2024-01-17 15:27:23 -08:00
Angelos Katharopoulos
135fd796d2 Fix detach for multi-output primitives (#480) 2024-01-17 14:08:07 -08:00
Jagrit Digani
78102a47ad Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Awni Hannun
275db7221a Command buffer reports errors (#479)
* command buffer reports errors

* typo

* simplify
2024-01-17 11:53:30 -08:00
Angelos Katharopoulos
d8fabaa12b Split multi output (#461)
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
2024-01-16 13:33:55 -08:00
Awni Hannun
a2ffea683a Fix eye for larger matrices (#463)
* fix eye
* fix scatter for <32bit (non native atomic) types
* fix int overflow
2024-01-16 00:51:24 -08:00
Angelos Katharopoulos
c15fe3e61b Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Awni Hannun
c9934fe8a4 Metal validation (#432)
* tests clear metal validation

* add cpp test with metal validation to circleci

* nit
2024-01-11 11:57:24 -08:00
Awni Hannun
f099ebe535 Multi output primitives (#330)
* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
Nripesh Niketan
73321b8097 feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Angelos Katharopoulos
a611b0bc82 Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Awni Hannun
c6d2878c1a safely divide for 0 size inputs (#388) 2024-01-07 00:19:54 -08:00
Awni Hannun
b9e415d19c bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00