Mirror of https://github.com/ml-explore/mlx.git

* Fast Inference SDPA op

  Implements Metal shaders for: o = mx.fast_inference_sdpa(queries, keys, values, scale, mask). Supports fp16 and fp32 dtypes; assumes d_k = 128. Generic op support / prompt encoding is handled via mlx primitives; the Metal implementation covers the inference use case only. The majority of the performance benefit appears to result from GQA and reduced bandwidth requirements; there is approximate performance parity for the MHA use case (from some measurements on an M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)
* Move to fast:: namespace, address reviewer comments; also attempt to revert formatter auto-changes for files not relevant to this change
* Shared memory flush to top of kernel
* Resolve compiler warnings
* Update python/src/fast.cpp (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
* Update python/src/fast.cpp (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
* Update python/src/fast.cpp (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
* Update python/src/fast.cpp (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
* Update docstring per PR feedback
* Softmax in higher precision, ...
* Route to fallback for more use cases: batch size > 1, head_dim other than 128, etc.
* Address Linux build failure
* Address other reviewer comments
* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
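The op described above landed under the fast:: namespace; in current MLX the Python entry point is mx.fast.scaled_dot_product_attention (that name and the scale keyword are assumptions drawn from today's API, not from this commit message). A minimal usage sketch of the fast-path configuration the message describes, i.e. batch size 1, d_k = 128, fp16, GQA:

import math

import mlx.core as mx

# Hypothetical decode-step shapes matching the kernel's constraints:
# batch size 1, head_dim 128, fp16. n_kv_heads < n_heads exercises GQA,
# where the message notes most of the measured speedup comes from.
B, n_heads, n_kv_heads, L_q, L_kv, D = 1, 32, 8, 1, 1024, 128
q = mx.random.normal((B, n_heads, L_q, D), dtype=mx.float16)
k = mx.random.normal((B, n_kv_heads, L_kv, D), dtype=mx.float16)
v = mx.random.normal((B, n_kv_heads, L_kv, D), dtype=mx.float16)

o = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
print(o.shape)  # (1, 32, 1, 128)

Per the routing step in the message, shapes outside this envelope (batch size > 1, head_dim other than 128) fall back to the composable-primitives implementation.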
106 lines · 2.0 KiB · C++
// Copyright © 2023-2024 Apple Inc.

#include <stdexcept> // for std::runtime_error

#include "mlx/primitives.h"

#include "mlx/fast_primitives.h"

#define NO_GPU_MULTI(func)                                             \
  void func::eval_gpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no GPU implementation.");     \
  }

#define NO_GPU(func)                                                  \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no GPU implementation.");    \
  }
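// Illustration (added editorially, not in the upstream file): NO_GPU(Abs)
// below expands to
//
//   void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
//     throw std::runtime_error("Abs has no GPU implementation.");
//   }
//
// so on a build without a GPU backend every primitive still links, but its
// eval_gpu throws at runtime. NO_GPU_MULTI is the same stub for primitives
// whose eval_gpu produces multiple outputs.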
namespace mlx::core {

NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(Arange)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
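// Editorial note: the stubs below are the no-GPU counterparts of the fast::
// primitives added in the commit above. The Metal shaders for
// ScaledDotProductAttention live in the Metal backend; this translation unit
// only makes eval_gpu throw on builds without one.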
namespace fast {
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
} // namespace fast

} // namespace mlx::core