// mlx/backend/no_gpu/primitives.cpp
// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"
// Defines a stub eval_gpu for a multi-output primitive: the body always
// throws, because this backend ships no GPU kernels.
#define NO_GPU_MULTI(prim)                                             \
  void prim::eval_gpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#prim " has no GPU implementation.");     \
  }
// Like NO_GPU_MULTI, but additionally defines use_fallback to return true
// unconditionally, so callers always take the non-fused fallback path.
#define NO_GPU_USE_FALLBACK(prim)   \
  bool prim::use_fallback(Stream s) \
  {                                 \
    return true;                    \
  }                                 \
  NO_GPU_MULTI(prim)
// Defines a stub eval_gpu for a single-output primitive: the body always
// throws, because this backend ships no GPU kernels.
#define NO_GPU(prim)                                                  \
  void prim::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#prim " has no GPU implementation.");    \
  }
namespace mlx::core {
// The fused SDPA kernel is never available on the no-GPU backend, so the
// generic fallback implementation is used regardless of the inputs.
// Parameter names are commented out since none of them are inspected here.
bool fast::ScaledDotProductAttention::use_fallback(
    const array& /*q*/,
    const array& /*k*/,
    const array& /*v*/,
    bool /*has_mask*/,
    bool /*has_arr_mask*/,
    bool /*do_causal*/,
    Stream /*s*/) {
  return true;
}
// GPU stubs for every core primitive: each invocation expands to an
// eval_gpu override that throws at runtime, since this backend has no GPU
// support. (Interleaved git-blame timestamp lines from the scraped page were
// removed here — they are not valid C++ and broke compilation.)
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(Arange)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTan2)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BitwiseInvert)
NO_GPU(BlockMaskedMM)
NO_GPU(Broadcast)
NO_GPU(BroadcastAxes)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Conjugate)
NO_GPU(Contiguous)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(NumberOfElements)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Flatten)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Squeeze)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Unflatten)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
NO_GPU_MULTI(Eig)
NO_GPU(View)
namespace fast {
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
Fast Inference SDPA op (#735) * Fast Inference SDPA op Implements metal shaders for: o = mx.fast_inference_sdpa(queries, keys, values, scale, mask) Supports fp16, fp32 dtypes; assumes d_k = 128. Generic op support / prompt encoding supported via mlx primitives. Metal implementation is for the inference use case only. Majority of performance benefits appears to results from GQA & reduced bandwidth requirements; there is approximate performance parity for the MHA use case (from some measurements on M3 Max). * Flush shared memory to zero before unprotected reads for (scores @ values) * Move to fast:: namespace, address reviewer comments ... also attempt to revert formatter auto-change for files not relevant to this change * Shared memory flush to top of kernel * Resolve compiler warnings * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/fast.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update docstring per PR feedback * Softmax in higher precision, ... * route to fallback for more use cases - batch size > 1, head_dim other than 128, etc. * Address linux build failure * Address other reviewer comments * Remove extraneous eval_cpu function per review --------- Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: atila <atiorh@icloud.com>
2024-03-05 13:06:11 +08:00
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
// Distributed collectives likewise have no GPU kernels on this backend.
// Listed alphabetically; definition order has no effect on behavior.
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(Send)
} // namespace distributed
} // namespace mlx::core