* Metal shaders for efficient self attention on large sequences
Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
The arm64 macbook pros are heavy and I usually care my intel one for
mobile, it would be nice if I can play with MLX on it.
To build with x64, user must pass `MLX_ENABLE_X64_MAC` to cmake:
CMAKE_ARGS='-DMLX_ENABLE_X64_MAC=ON' python setup.py
* Fast Inference SDPA op
Implements metal shaders for:
o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)
Supports fp16, fp32 dtypes; assumes d_k = 128.
Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.
Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).
* Flush shared memory to zero before unprotected reads for (scores @ values)
* Move to fast:: namespace, address reviewer comments
... also attempt to revert formatter auto-change for files not relevant
to this change
* Shared memory flush to top of kernel
* Resolve compiler warnings
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/fast.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update docstring per PR feedback
* Softmax in higher precision, ...
* route to fallback for more use cases - batch size > 1, head_dim other
than 128, etc.
* Address linux build failure
* Address other reviewer comments
* Remove extraneous eval_cpu function per review
---------
Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>