mlx/python/tests
Brian Keene 0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

The majority of the performance benefit appears to result from GQA and
reduced bandwidth requirements; performance is at approximate parity for
the MHA use case (based on some measurements on an M3 Max).
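
For orientation, the computation described above can be sketched in plain Python for a single decode step (one query vector per head). This is a reference illustration only, not the Metal kernel and not the MLX API: the function name `decode_step_attention`, the nested-list layout, and the grouping rule (consecutive query heads share one key/value head) are assumptions made for this example, and the mask argument is left out for brevity.

```python
import math


def decode_step_attention(queries, keys, values, scale):
    """Scaled dot-product attention for one decode step (hypothetical helper).

    queries: n_q_heads vectors of length d_k
    keys, values: n_kv_heads lists, each holding one vector per cached token
    With n_q_heads > n_kv_heads this is the GQA case: several query heads
    read the same key/value head, so fewer K/V bytes are loaded per output.
    """
    n_q_heads, n_kv_heads = len(queries), len(keys)
    assert n_q_heads % n_kv_heads == 0, "query heads must divide evenly into groups"
    group_size = n_q_heads // n_kv_heads

    outputs = []
    for head, q in enumerate(queries):
        kv_head = head // group_size  # assumed grouping: consecutive heads share K/V
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys[kv_head]]

        # Numerically stable softmax: subtract the max before exponentiating.
        max_score = max(scores)
        exps = [math.exp(s - max_score) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]

        # Weighted sum of the cached value vectors (scores @ values).
        dim = len(values[kv_head][0])
        outputs.append(
            [sum(w * v[d] for w, v in zip(weights, values[kv_head])) for d in range(dim)]
        )
    return outputs


# Toy sizes for readability (the commit's kernel assumes d_k = 128).
queries = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.4]]  # 4 query heads
keys = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],    # kv head 0, 3 cached tokens
    [[0.5, 0.5], [-1.0, 0.0], [0.0, -1.0]],  # kv head 1
]
values = [
    [[1.0, 2.0]] * 3,  # constant values make the expected output easy to check
    [[3.0, 4.0]] * 3,
]
out = decode_step_attention(queries, keys, values, scale=1.0 / math.sqrt(2))
```

Because the softmax weights sum to one and each key/value head here holds constant value vectors, query heads 0 and 1 return approximately `[1.0, 2.0]` and heads 2 and 3 return approximately `[3.0, 4.0]`, which shows the head-to-group mapping directly.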

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review
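
The "route to fallback" item above can be pictured as a small eligibility check. The sketch below is hypothetical: the function name `can_use_fused_kernel`, the `(batch, heads, seq_len, head_dim)` shape layout, and the dtype strings are assumptions for illustration, and it encodes only the conditions this commit message names (batch size, head_dim, and the fp16/fp32 dtypes). The actual dispatch logic in MLX may check more than this.

```python
SUPPORTED_DTYPES = ("float16", "float32")  # dtypes named in the commit message
FUSED_HEAD_DIM = 128                       # the kernel assumes d_k = 128


def can_use_fused_kernel(query_shape, dtype):
    """Return True if the fused path applies, else the caller should fall back.

    query_shape is assumed to be (batch, n_heads, seq_len, head_dim).
    """
    batch, _n_heads, _seq_len, head_dim = query_shape
    return batch == 1 and head_dim == FUSED_HEAD_DIM and dtype in SUPPORTED_DTYPES


fused = can_use_fused_kernel((1, 32, 1, 128), "float16")
fallback_batch = can_use_fused_kernel((2, 32, 1, 128), "float16")
fallback_dim = can_use_fused_kernel((1, 32, 1, 64), "float32")
```

Keeping the check in one predicate means any input that fails it can be sent to the generic path built from existing primitives, so callers get a correct result either way.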

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
mlx_tests.py Buffer Donation (#519) 2024-01-26 16:30:33 -08:00
test_array.py Pickle + dtype fix for numpy conversion (#763) 2024-03-02 06:09:29 -08:00
test_autograd.py Fix logsumexp edge case (#740) 2024-02-25 08:39:55 -08:00
test_bf16.py bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00
test_blas.py minor fixes (#631) 2024-02-05 13:27:49 -08:00
test_compile.py Fix compile with non standard types (#745) 2024-02-26 19:28:53 -08:00
test_constants.py feat: Add numpy constants (#428) 2024-01-11 06:47:29 -08:00
test_conv.py Convolution update (#651) 2024-02-28 20:11:16 -08:00
test_device.py Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
test_eval.py Compile with capture (#629) 2024-02-07 17:29:22 -08:00
test_fast_sdpa.py Fast Inference SDPA op (#735) 2024-03-04 21:06:11 -08:00
test_fast.py Custom primitive + RoPE fat op (#676) 2024-02-14 14:04:25 -08:00
test_fft.py Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
test_graph.py Multi output primitives (#330) 2024-01-08 16:39:08 -08:00
test_init.py Make shape a tuple (#591) 2024-01-30 13:11:01 -08:00
test_linalg.py QR factorization (#310) 2024-01-26 09:27:31 -08:00
test_load.py Fix logsumexp edge case (#740) 2024-02-25 08:39:55 -08:00
test_losses.py Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
test_metal.py bindings for memory info (#761) 2024-03-01 19:51:58 -08:00
test_nn.py Dilation for convolutional layers (#766) 2024-03-04 06:43:00 -08:00
test_ops.py Fix the top-k op (#768) 2024-03-01 22:08:43 -08:00
test_optimizers.py Add linear warmup and schedule joining for use with existing schedules (#721) 2024-02-26 07:28:48 -08:00
test_quantized.py Quantized matmul fix (#677) 2024-02-12 18:54:21 -08:00
test_random.py Add loc and scale to random.normal (#638) 2024-02-07 11:49:59 -08:00
test_reduce.py Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
test_tree.py Add isort pre-commit and run (#68) 2023-12-08 11:31:47 -08:00
test_vmap.py Reduce vmap + some fixes (#601) 2024-02-01 11:30:28 -08:00