Fast primitives decide when to use the fallback (#2216)

Author: Cheng, 2025-06-03 05:26:37 +09:00 (committed by GitHub)
parent 95b7551d65
commit 1b021f6984
7 changed files with 115 additions and 45 deletions
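In short: each fast primitive now exposes a static use_fallback(...) predicate, and the op builders in mlx/fast.cpp ask that predicate whether to emit the custom primitive or evaluate the reference fallback graph, instead of hard-coding stream.device == Device::gpu checks at every call site. A minimal sketch of the dispatch pattern this introduces, with hypothetical names (my_fast_op, MyFastPrimitive) standing in for the real primitives touched below:

#include <memory>
#include <vector>

#include "mlx/array.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"

namespace mlx::core::fast {

// Hypothetical op builder following the pattern used in mlx/fast.cpp.
array my_fast_op(const array& x, StreamOrDevice s_ = {}) {
  auto s = to_stream(s_);
  // Reference implementation built from ordinary ops; always valid, and used
  // whenever the primitive declines the fast path.
  auto fallback = [](std::vector<array> inputs) -> std::vector<array> {
    return {inputs[0]}; // placeholder body
  };
  if (!MyFastPrimitive::use_fallback(s)) {
    // Fast path: emit the custom primitive; the backend supplies eval_gpu.
    return array(
        x.shape(),
        x.dtype(),
        std::make_shared<MyFastPrimitive>(s, fallback),
        {x});
  }
  // Fallback path: evaluate the reference graph eagerly.
  return fallback({x})[0];
}

} // namespace mlx::core::fast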

View File

@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
 #define NO_GPU_MULTI(func) \
   void func::eval_gpu( \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
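For reference, here is what the new macro expands to for one primitive; this is just the preprocessor output implied by the NO_GPU_USE_FALLBACK and NO_GPU_MULTI definitions above, written out for LayerNorm rather than extra code in the commit:

// NO_GPU_USE_FALLBACK(LayerNorm) expands to:
bool LayerNorm::use_fallback(Stream s) {
  return true; // no CUDA kernel, so always build the fallback graph
}
void LayerNorm::eval_gpu(
    const std::vector<array>& inputs, std::vector<array>& outputs) {
  throw std::runtime_error("LayerNorm has no CUDA implementation.");
}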

View File

@@ -10,6 +10,10 @@
 namespace mlx::core::fast {
+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
   }
 }
+bool LayerNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
 void LayerNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

View File

@@ -7,6 +7,10 @@ namespace mlx::core::fast {
 constexpr int n_per_thread = 4;
+bool RoPE::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
 void RoPE::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

View File

@@ -4,10 +4,10 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 namespace mlx::core::fast {
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
 } // namespace
+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
+  if (s.device == Device::cpu) {
+    return true;
+  }
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
+       query_head_dim == 256);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+  const bool supports_sdpa_full =
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim;
+  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_head_dim;
+  return !(supports_sdpa_full || supports_sdpa_vector);
+}
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
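Read as a single predicate, the routing above declines the Metal fast path under gradient tracing or on a CPU stream, and otherwise requires shapes that one of the fused kernels supports. A compact restatement for readability (takes_metal_fast_path is a hypothetical helper, assuming the query and value head dims are equal; the authoritative logic is use_fallback itself):

// head_dim stands for q.shape(-1) == v.shape(-1); q_len/k_len for q.shape(2)/k.shape(2).
bool takes_metal_fast_path(
    int head_dim,
    int q_len,
    int k_len,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    bool grad_tracing,
    bool on_cpu) {
  if (grad_tracing || on_cpu) {
    return false;
  }
  // Short-query "vector" kernel.
  const bool vector_kernel = q_len <= 8 && q_len <= k_len &&
      (head_dim == 64 || head_dim == 96 || head_dim == 128 || head_dim == 256);
  // Full attention kernel: supported head dims plus a mask it can handle.
  const bool full_head_dim =
      head_dim == 64 || head_dim == 80 || head_dim == 128;
  const bool full_mask_ok =
      !has_mask || has_arr_mask || (do_causal && q_len <= k_len);
  return vector_kernel || (full_head_dim && full_mask_ok);
}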

View File

@@ -10,6 +10,12 @@
     throw std::runtime_error(#func " has no GPU implementation."); \
   }
+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no GPU implementation."); \
@@ -17,6 +23,17 @@
 namespace mlx::core {
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
 NO_GPU(Abs)
 NO_GPU(Add)
 NO_GPU(AddMM)
@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU(View)
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)

View File

@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"
 namespace mlx::core::fast {
@@ -112,7 +111,8 @@ array rms_norm(
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };
   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
             stream, fallback, scale, do_causal),
         std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {

View File

@@ -43,6 +43,8 @@ class RMSNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
+  static bool use_fallback(Stream stream);
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -65,7 +67,6 @@ class RMSNorm : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -91,7 +92,6 @@ class RMSNormVJP : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -103,6 +103,8 @@ class LayerNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
+  static bool use_fallback(Stream s);
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -124,7 +126,6 @@ class LayerNorm : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -150,7 +151,6 @@ class LayerNormVJP : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -171,6 +171,8 @@ class RoPE : public Custom {
         scale_(scale),
         forward_(forward) {}
+  static bool use_fallback(Stream s);
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -193,7 +195,6 @@ class RoPE : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int dims_;
   bool traditional_;
   float base_;
@@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom {
       const bool do_causal)
       : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
+  static bool use_fallback(
+      const array& q,
+      const array& k,
+      const array& v,
+      bool has_mask,
+      bool has_arr_mask,
+      bool do_causal,
+      Stream s);
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -230,7 +240,6 @@ class ScaledDotProductAttention : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool do_causal_;
 };
@@ -263,7 +272,6 @@ class AffineQuantize : public Custom {
   }
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int group_size_;
   int bits_;
   bool dequantize_;