Commit Graph

23 Commits

Author SHA1 Message Date
Awni Hannun
4fda5fbdf9
add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Awni Hannun
af705590ac
fix batched vector sdpa (#2152) 2025-05-05 13:13:03 -07:00
Angelos Katharopoulos
c4189a38e4
Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
00794c42bc
Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Angelos Katharopoulos
ec2854b13a
Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
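
Note on #2029: the commit message gives no rationale, but one common reason for such a swap is that a row whose scores are all -inf turns into NaN under a max-subtracted softmax. The sketch below is an illustration in plain Python, not MLX's kernel code. `softmax` is a hypothetical helper, and `FINITE_MIN` is an assumed stand-in for the finite minimum of the dtype.

```python
import math
import sys


def softmax(row):
    """Max-subtracted softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


NEG_INF = float("-inf")
FINITE_MIN = -sys.float_info.max  # assumed stand-in for a dtype's finite minimum

# Every position masked with -inf: (-inf) - (-inf) is NaN, so the row is NaN.
fully_masked_inf = softmax([NEG_INF, NEG_INF])

# Every position masked with a finite minimum: the row stays well defined.
fully_masked_finite = softmax([FINITE_MIN, FINITE_MIN])

# A partially masked row still gives the masked position zero weight.
partially_masked = softmax([0.0, FINITE_MIN])
```

With the finite value, a fully masked row degrades to uniform weights instead of NaN, while a partially masked row behaves as it would with -inf.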
Awni Hannun
05d7118561
causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
a84cc0123f
promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Awni Hannun
005e7efa64
fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84
Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
3c3e558c60
Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv

* fix flaky test

* nit
2025-03-10 10:53:45 -07:00
Awni Hannun
e613d0eaf0
SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa

* batch sdpa for query
2025-03-04 10:59:04 -08:00
Angelos Katharopoulos
f5cc1eea72
Allow different value dimensions in sdpa_vector (#1811) 2025-01-31 20:58:59 -08:00
Awni Hannun
d1766f2c70
Add boolean mask support in vector SDPA (#1757) 2025-01-07 20:24:53 -08:00
Awni Hannun
d5ec172c95
Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa

* more permissive donation in ternary
2025-01-06 16:57:07 -08:00
Awni Hannun
91c0277356
fix per-example mask + docs in sdpa (#1574) 2024-11-08 11:51:15 -08:00
Angelos Katharopoulos
62f297b51d
Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Brian Keene
19fb69e2ed
Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to the memory efficient GPU shader at a prescribed sequence
length.  Otherwise, uses aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
Nikhil Mehta
0b7d71fd2f
Add softmin, hardshrink, hardtanh (#1180)
---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
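
Note on #964: the "scale correction" mentioned above presumably refers to the running-max rescaling used in flash-attention-style kernels, where scores are processed in blocks and earlier partial sums are rescaled whenever a larger maximum appears. The following is a minimal plain-Python sketch of that idea for one query row with scalar values. It is not the Metal shader, and the function names and the `block` parameter are hypothetical.

```python
import math


def attention_row_blocked(scores, values, block=2):
    """Softmax-weighted sum of scalar values, processed block by block."""
    running_max = float("-inf")
    running_sum = 0.0  # softmax denominator, relative to running_max
    acc = 0.0          # weighted sum of values, relative to running_max
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Scale correction: re-express earlier partial results
        # relative to the new maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc *= correction
        for s, v in zip(s_blk, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc += w * v
        running_max = new_max
    return acc / running_sum


def attention_row_direct(scores, values):
    """Same quantity computed with the full row available at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [0.5, 3.0, -1.0, 7.0, 2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
blocked = attention_row_blocked(scores, values, block=2)
direct = attention_row_direct(scores, values)
```

Both functions should agree up to floating-point rounding; the blocked version never needs the full score row in memory at once, which is the point of the technique.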
Cheng
f30b659291
Make MLX build on x64 macOS (#901)
The arm64 MacBook Pros are heavy and I usually carry my Intel one when
mobile; it would be nice to be able to play with MLX on it.

To build for x64, the user must pass `MLX_ENABLE_X64_MAC` to cmake:
CMAKE_ARGS='-DMLX_ENABLE_X64_MAC=ON' python setup.py
2024-03-27 06:14:29 -07:00
Awni Hannun
859ae15a54
Fix test (#785) 2024-03-04 23:02:27 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

The majority of the performance benefit appears to result from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
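
Note on #735: for readers unfamiliar with the operation, the sketch below shows the standard scaled dot-product attention computation, softmax(scale * Q K^T + mask) V, for a single head in plain Python. It is a reference illustration only and not the MLX implementation. `sdpa_reference` is a hypothetical name, and treating the mask as additive is an assumption; the argument order simply mirrors the call shown in the commit message.

```python
import math


def sdpa_reference(queries, keys, values, scale, mask=None):
    """Single-head attention over lists of row vectors.

    queries: [L_q][D], keys: [L_k][D], values: [L_k][D_v]
    mask:    optional additive mask of shape [L_q][L_k]
    """
    out = []
    for i, q in enumerate(queries):
        # Scaled dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        # Max-subtracted softmax over the key axis.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        # Weighted average of the value rows.
        out.append([
            sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))
        ])
    return out


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]

# A zero query scores every key equally, so the output is the mean of the values.
unmasked = sdpa_reference([[0.0, 0.0]], keys, values, scale=1.0)

# Masking out the second key leaves only the first value row.
masked = sdpa_reference(
    [[0.0, 0.0]], keys, values, scale=1.0, mask=[[0.0, float("-inf")]]
)
```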