mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fast primitives decide when to use the fallback (#2216)
This commit is contained in:
@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
});
|
||||
}
|
||||
|
||||
bool fast::ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
#define NO_GPU_USE_FALLBACK(func) \
|
||||
bool func::use_fallback(Stream s) { \
|
||||
return true; \
|
||||
} \
|
||||
NO_GPU_MULTI(func)
|
||||
|
||||
#define NO_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
||||
@@ -10,6 +10,10 @@
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
bool RMSNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RMSNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
bool LayerNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
|
||||
@@ -7,6 +7,10 @@ namespace mlx::core::fast {
|
||||
|
||||
constexpr int n_per_thread = 4;
|
||||
|
||||
bool RoPE::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int value_head_dim = v.shape(-1);
|
||||
const int query_head_dim = q.shape(-1);
|
||||
const int query_sequence_length = q.shape(2);
|
||||
const int key_sequence_length = k.shape(2);
|
||||
|
||||
const bool sdpa_vector_supported_head_dim =
|
||||
query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
|
||||
query_head_dim == 256);
|
||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||
(query_sequence_length <= key_sequence_length && do_causal);
|
||||
|
||||
const bool supports_sdpa_full =
|
||||
sdpa_full_supported_mask && sdpa_full_supported_head_dim;
|
||||
|
||||
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
||||
(query_sequence_length <= key_sequence_length) &&
|
||||
sdpa_vector_supported_head_dim;
|
||||
|
||||
return !(supports_sdpa_full || supports_sdpa_vector);
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
|
||||
@@ -10,6 +10,12 @@
|
||||
throw std::runtime_error(#func " has no GPU implementation."); \
|
||||
}
|
||||
|
||||
#define NO_GPU_USE_FALLBACK(func) \
|
||||
bool func::use_fallback(Stream s) { \
|
||||
return true; \
|
||||
} \
|
||||
NO_GPU_MULTI(func)
|
||||
|
||||
#define NO_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
throw std::runtime_error(#func " has no GPU implementation."); \
|
||||
@@ -17,6 +23,17 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
bool fast::ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
return true;
|
||||
}
|
||||
|
||||
NO_GPU(Abs)
|
||||
NO_GPU(Add)
|
||||
NO_GPU(AddMM)
|
||||
@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU(View)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
||||
Reference in New Issue
Block a user