diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index d105a242b..11de02c8e 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 #define NO_GPU_MULTI(func) \
   void func::eval_gpu( \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index 21142183e..c0901ccec 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -10,6 +10,10 @@
 
 namespace mlx::core::fast {
 
+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
   }
 }
 
+bool LayerNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void LayerNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index d8201afe6..e141df630 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -7,6 +7,10 @@ namespace mlx::core::fast {
 
 constexpr int n_per_thread = 4;
 
+bool RoPE::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RoPE::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 3c7b7ff19..aad1a0018 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -4,10 +4,10 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
-
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core::fast {
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
 
 } // namespace
 
+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
+  if (s.device == Device::cpu) {
+    return true;
+  }
+
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
+
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
+       query_head_dim == 256);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+
+  const bool supports_sdpa_full =
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim;
+
+  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_head_dim;
+
+  return !(supports_sdpa_full || supports_sdpa_vector);
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp
index 676a6e550..409aa2c89 100644
--- a/mlx/backend/no_gpu/primitives.cpp
+++ b/mlx/backend/no_gpu/primitives.cpp
@@ -10,6 +10,12 @@
     throw std::runtime_error(#func " has no GPU implementation."); \
   }
 
+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no GPU implementation."); \
@@ -17,6 +23,17 @@
 
 namespace mlx::core {
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 NO_GPU(Abs)
 NO_GPU(Add)
 NO_GPU(AddMM)
@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU(View)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index c77b97de5..eab22f14d 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
 
@@ -112,7 +111,8 @@ array rms_norm(
 
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
 
-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };
 
   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
-
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
       mask_shape.back() = keys.shape(-2);
       inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
     }
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
             stream, fallback, scale, do_causal),
         std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }
 
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 4d9e505ee..51050ea50 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -43,6 +43,8 @@ class RMSNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
+  static bool use_fallback(Stream stream);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -65,7 +67,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -91,7 +92,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -103,6 +103,8 @@ class LayerNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -124,7 +126,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -150,7 +151,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
 
@@ -171,6 +171,8 @@
         scale_(scale),
         forward_(forward) {}
 
+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -193,7 +195,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int dims_;
   bool traditional_;
   float base_;
@@ -210,6 +211,15 @@
       const bool do_causal)
       : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
 
+  static bool use_fallback(
+      const array& q,
+      const array& k,
+      const array& v,
+      bool has_mask,
+      bool has_arr_mask,
+      bool do_causal,
+      Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -230,7 +240,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool do_causal_;
 };
@@ -263,7 +272,6 @@
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int group_size_;
   int bits_;
   bool dequantize_;