Mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-26 10:41:14 +08:00
* Fast Inference SDPA op

Implements Metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16 and fp32 dtypes; assumes d_k = 128.

Generic op support and prompt encoding are supported via MLX primitives. The Metal implementation is for the inference use case only.

The majority of the performance benefit appears to result from GQA and reduced bandwidth requirements; performance is at approximate parity for the MHA use case (based on some measurements on an M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)
* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant to this change

* Move shared memory flush to top of kernel
* Resolve compiler warnings
* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback
* Softmax in higher precision, ...
* Route to fallback for more use cases: batch size > 1, head_dim other than 128, etc.
* Address Linux build failure
* Address other reviewer comments
* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
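The computation the op performs, o = softmax(scale · queries · keysᵀ + mask) · values, and the GQA case the message credits for most of the speedup can be sketched as a plain NumPy reference. This is not the Metal kernel and not MLX's API: the function name `reference_sdpa`, the (batch, heads, seq_len, head_dim) layout, the additive mask, and the mapping of query head i to key/value head i // (n_q_heads // n_kv_heads) are assumptions made for illustration. Scores and softmax are computed in float32 even for fp16 inputs, mirroring the "softmax in higher precision" item.

```python
import numpy as np


def reference_sdpa(queries, keys, values, scale, mask=None):
    """Hypothetical reference: softmax(queries @ keys^T * scale + mask) @ values.

    Assumed shapes: queries (B, Hq, Lq, D); keys and values (B, Hkv, Lk, D),
    with Hq a multiple of Hkv (grouped-query attention when Hkv < Hq).
    """
    n_q_heads, n_kv_heads = queries.shape[1], keys.shape[1]
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("number of query heads must be a multiple of key/value heads")
    repeats = n_q_heads // n_kv_heads

    # GQA (assumed mapping): each key/value head serves `repeats` consecutive
    # query heads, so query head i reads key/value head i // repeats.
    keys = np.repeat(keys, repeats, axis=1)
    values = np.repeat(values, repeats, axis=1)

    # Scores and softmax in float32, even when the inputs are float16.
    q32 = queries.astype(np.float32)
    k32 = keys.astype(np.float32)
    scores = (q32 @ np.swapaxes(k32, -1, -2)) * scale  # (B, Hq, Lq, Lk)
    if mask is not None:
        scores = scores + mask  # additive mask, broadcast against the scores

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    out = weights @ values.astype(np.float32)  # (B, Hq, Lq, D)
    return out.astype(queries.dtype)


# Example: one decode step (a single query token), 8 query heads sharing
# 2 key/value heads, head_dim 128, fp16 inputs.
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 8, 1, 128)).astype(np.float16)
k = rng.standard_normal((1, 2, 16, 128)).astype(np.float16)
v = rng.standard_normal((1, 2, 16, 128)).astype(np.float16)
o = reference_sdpa(q, k, v, scale=128 ** -0.5)
print(o.shape, o.dtype)  # (1, 8, 1, 128) float16
```

Here `np.repeat` materializes the repeated keys and values for clarity; a fused kernel can instead read each key/value head once for all the query heads that share it, which is one plausible source of the bandwidth savings the message describes.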
mlx_tests.py
test_array.py
test_autograd.py
test_bf16.py
test_blas.py
test_compile.py
test_constants.py
test_conv.py
test_device.py
test_eval.py
test_fast_sdpa.py
test_fast.py
test_fft.py
test_graph.py
test_init.py
test_linalg.py
test_load.py
test_losses.py
test_metal.py
test_nn.py
test_ops.py
test_optimizers.py
test_quantized.py
test_random.py
test_reduce.py
test_tree.py
test_vmap.py