2024-01-31 08:04:45 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
#include "mlx/primitives.h"
|
2024-08-27 06:12:50 +08:00
|
|
|
#include "mlx/distributed/primitives.h"
|
2024-02-17 11:16:39 +08:00
|
|
|
#include "mlx/fast_primitives.h"
|
2023-11-30 02:52:08 +08:00
|
|
|
|
2024-01-09 08:39:08 +08:00
|
|
|
// NO_GPU_MULTI(prim): defines the multi-output `prim::eval_gpu` so that any
// attempt to evaluate `prim` on the GPU throws std::runtime_error whose
// message names the primitive. The parameters are intentionally unnamed
// because the stub never reads them.
#define NO_GPU_MULTI(prim)                                         \
  void prim::eval_gpu(                                             \
      const std::vector<array>& /* inputs */,                      \
      std::vector<array>& /* outputs */) {                         \
    throw std::runtime_error(#prim " has no GPU implementation."); \
  }
|
|
|
|
|
2025-06-03 04:26:37 +08:00
|
|
|
// NO_GPU_USE_FALLBACK(prim): for a primitive that declares
// `use_fallback(Stream)`, defines it to always return true. It then also
// defines a throwing multi-output `eval_gpu` by expanding NO_GPU_MULTI(prim).
// NOTE(review): presumably callers check use_fallback before dispatching to
// eval_gpu, which would make the throw a safety net only — confirm at the
// call sites.
#define NO_GPU_USE_FALLBACK(prim)           \
  bool prim::use_fallback(Stream /* s */) { \
    return true;                            \
  }                                         \
  NO_GPU_MULTI(prim)
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
// NO_GPU(prim): defines the single-output `prim::eval_gpu` so that any
// attempt to evaluate `prim` on the GPU throws std::runtime_error whose
// message names the primitive. The parameters are intentionally unnamed
// because the stub never reads them.
#define NO_GPU(prim)                                               \
  void prim::eval_gpu(                                             \
      const std::vector<array>& /* inputs */, array& /* out */) {  \
    throw std::runtime_error(#prim " has no GPU implementation."); \
  }
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
2025-06-03 04:26:37 +08:00
|
|
|
bool fast::ScaledDotProductAttention::use_fallback(
|
|
|
|
const array& q,
|
|
|
|
const array& k,
|
|
|
|
const array& v,
|
|
|
|
bool has_mask,
|
|
|
|
bool has_arr_mask,
|
|
|
|
bool do_causal,
|
|
|
|
Stream s) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Abs)
|
|
|
|
NO_GPU(Add)
|
2024-01-18 04:42:39 +08:00
|
|
|
NO_GPU(AddMM)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Arange)
|
|
|
|
NO_GPU(ArcCos)
|
|
|
|
NO_GPU(ArcCosh)
|
|
|
|
NO_GPU(ArcSin)
|
|
|
|
NO_GPU(ArcSinh)
|
|
|
|
NO_GPU(ArcTan)
|
2024-05-08 23:35:15 +08:00
|
|
|
NO_GPU(ArcTan2)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(ArcTanh)
|
|
|
|
NO_GPU(ArgPartition)
|
|
|
|
NO_GPU(ArgReduce)
|
|
|
|
NO_GPU(ArgSort)
|
|
|
|
NO_GPU(AsType)
|
|
|
|
NO_GPU(AsStrided)
|
2024-04-27 13:03:42 +08:00
|
|
|
NO_GPU(BitwiseBinary)
|
2025-02-14 00:44:14 +08:00
|
|
|
NO_GPU(BitwiseInvert)
|
2024-05-03 05:03:58 +08:00
|
|
|
NO_GPU(BlockMaskedMM)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Broadcast)
|
2025-01-10 03:04:24 +08:00
|
|
|
NO_GPU(BroadcastAxes)
|
2023-12-15 02:00:23 +08:00
|
|
|
NO_GPU(Ceil)
|
2024-02-05 22:51:22 +08:00
|
|
|
NO_GPU_MULTI(Compiled)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Concatenate)
|
2024-05-10 22:22:20 +08:00
|
|
|
NO_GPU(Conjugate)
|
2024-11-22 11:51:49 +08:00
|
|
|
NO_GPU(Contiguous)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Convolution)
|
|
|
|
NO_GPU(Copy)
|
|
|
|
NO_GPU(Cos)
|
|
|
|
NO_GPU(Cosh)
|
2024-07-11 09:00:01 +08:00
|
|
|
NO_GPU_MULTI(CustomTransforms)
|
2024-01-31 08:04:45 +08:00
|
|
|
NO_GPU_MULTI(Depends)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Divide)
|
2024-02-05 22:51:22 +08:00
|
|
|
NO_GPU_MULTI(DivMod)
|
2025-01-08 06:02:16 +08:00
|
|
|
NO_GPU(DynamicSlice)
|
|
|
|
NO_GPU(DynamicSliceUpdate)
|
2024-03-14 01:34:14 +08:00
|
|
|
NO_GPU(NumberOfElements)
|
2023-12-09 07:08:52 +08:00
|
|
|
NO_GPU(Remainder)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Equal)
|
|
|
|
NO_GPU(Erf)
|
|
|
|
NO_GPU(ErfInv)
|
|
|
|
NO_GPU(Exp)
|
2024-12-11 08:39:07 +08:00
|
|
|
NO_GPU(ExpandDims)
|
2024-04-09 05:26:01 +08:00
|
|
|
NO_GPU(Expm1)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(FFT)
|
2024-12-12 13:51:37 +08:00
|
|
|
NO_GPU(Flatten)
|
2023-12-15 02:00:23 +08:00
|
|
|
NO_GPU(Floor)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Full)
|
|
|
|
NO_GPU(Gather)
|
2025-02-01 12:48:08 +08:00
|
|
|
NO_GPU(GatherAxis)
|
2024-05-22 22:48:34 +08:00
|
|
|
NO_GPU(GatherMM)
|
|
|
|
NO_GPU(GatherQMM)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Greater)
|
|
|
|
NO_GPU(GreaterEqual)
|
2024-07-10 11:39:01 +08:00
|
|
|
NO_GPU(Hadamard)
|
2024-10-16 07:23:15 +08:00
|
|
|
NO_GPU(Imag)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Less)
|
|
|
|
NO_GPU(LessEqual)
|
|
|
|
NO_GPU(Load)
|
|
|
|
NO_GPU(Log)
|
|
|
|
NO_GPU(Log1p)
|
|
|
|
NO_GPU(LogicalNot)
|
2024-01-08 23:00:05 +08:00
|
|
|
NO_GPU(LogicalAnd)
|
|
|
|
NO_GPU(LogicalOr)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(LogAddExp)
|
2025-03-31 22:36:55 +08:00
|
|
|
NO_GPU(LogSumExp)
|
2025-02-11 04:32:24 +08:00
|
|
|
NO_GPU_MULTI(LUF)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Matmul)
|
|
|
|
NO_GPU(Maximum)
|
|
|
|
NO_GPU(Minimum)
|
|
|
|
NO_GPU(Multiply)
|
|
|
|
NO_GPU(Negative)
|
|
|
|
NO_GPU(NotEqual)
|
|
|
|
NO_GPU(Pad)
|
|
|
|
NO_GPU(Partition)
|
|
|
|
NO_GPU(Power)
|
2024-02-05 22:51:22 +08:00
|
|
|
NO_GPU_MULTI(QRF)
|
2023-12-19 15:18:57 +08:00
|
|
|
NO_GPU(QuantizedMatmul)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(RandomBits)
|
2024-10-16 07:23:15 +08:00
|
|
|
NO_GPU(Real)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Reduce)
|
|
|
|
NO_GPU(Reshape)
|
2023-12-19 03:32:48 +08:00
|
|
|
NO_GPU(Round)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Scan)
|
|
|
|
NO_GPU(Scatter)
|
2025-02-01 12:48:08 +08:00
|
|
|
NO_GPU(ScatterAxis)
|
2024-02-23 07:10:48 +08:00
|
|
|
NO_GPU(Select)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Sigmoid)
|
|
|
|
NO_GPU(Sign)
|
|
|
|
NO_GPU(Sin)
|
|
|
|
NO_GPU(Sinh)
|
|
|
|
NO_GPU(Slice)
|
2024-03-21 01:39:25 +08:00
|
|
|
NO_GPU(SliceUpdate)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Softmax)
|
|
|
|
NO_GPU(Sort)
|
2024-01-17 05:33:55 +08:00
|
|
|
NO_GPU_MULTI(Split)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Square)
|
2024-12-11 08:39:07 +08:00
|
|
|
NO_GPU(Squeeze)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Sqrt)
|
|
|
|
NO_GPU(StopGradient)
|
|
|
|
NO_GPU(Subtract)
|
2024-03-13 03:30:11 +08:00
|
|
|
NO_GPU_MULTI(SVD)
|
2023-11-30 02:52:08 +08:00
|
|
|
NO_GPU(Tan)
|
|
|
|
NO_GPU(Tanh)
|
|
|
|
NO_GPU(Transpose)
|
2024-12-12 13:51:37 +08:00
|
|
|
NO_GPU(Unflatten)
|
2024-03-15 21:34:36 +08:00
|
|
|
NO_GPU(Inverse)
|
2024-05-18 03:31:59 +08:00
|
|
|
NO_GPU(Cholesky)
|
2024-10-23 03:18:48 +08:00
|
|
|
NO_GPU_MULTI(Eigh)
|
2025-05-16 04:01:44 +08:00
|
|
|
NO_GPU_MULTI(Eig)
|
2024-06-04 23:05:27 +08:00
|
|
|
NO_GPU(View)
|
2024-02-05 22:51:22 +08:00
|
|
|
|
2024-02-15 06:04:25 +08:00
|
|
|
namespace fast {
|
2025-06-03 04:26:37 +08:00
|
|
|
NO_GPU_USE_FALLBACK(LayerNorm)
|
2024-03-27 07:35:34 +08:00
|
|
|
NO_GPU_MULTI(LayerNormVJP)
|
2025-06-03 04:26:37 +08:00
|
|
|
NO_GPU_USE_FALLBACK(RMSNorm)
|
2024-03-27 07:35:34 +08:00
|
|
|
NO_GPU_MULTI(RMSNormVJP)
|
2025-06-03 04:26:37 +08:00
|
|
|
NO_GPU_USE_FALLBACK(RoPE)
|
2024-03-05 13:06:11 +08:00
|
|
|
NO_GPU(ScaledDotProductAttention)
|
2024-07-30 06:11:38 +08:00
|
|
|
NO_GPU_MULTI(AffineQuantize)
|
2024-08-23 04:46:29 +08:00
|
|
|
NO_GPU_MULTI(CustomKernel)
|
2024-02-15 06:04:25 +08:00
|
|
|
} // namespace fast
|
|
|
|
|
2024-08-27 06:12:50 +08:00
|
|
|
namespace distributed {
|
|
|
|
NO_GPU_MULTI(AllReduce)
|
|
|
|
NO_GPU_MULTI(AllGather)
|
2024-08-27 14:01:37 +08:00
|
|
|
NO_GPU_MULTI(Send)
|
|
|
|
NO_GPU_MULTI(Recv)
|
2024-08-27 06:12:50 +08:00
|
|
|
} // namespace distributed
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
} // namespace mlx::core
|