Fast primitives decide when to use the fallback (#2216)

Author: Cheng
Date: 2025-06-03 05:26:37 +09:00
Committed by: GitHub
Parent: 95b7551d65
Commit: 1b021f6984
7 changed files with 115 additions and 45 deletions
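The change moves the "should we fall back to the composed reference graph?" decision out of the call sites and into the fast primitives themselves, as a use_fallback() predicate. Below is a minimal sketch of the resulting call-site pattern, assuming illustrative names rather than the exact code in mlx/fast.cpp:

#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
#include "mlx/utils.h"

namespace mlx::core::fast {

// Sketch only: a simplified rms_norm front end. The real mlx/fast.cpp also
// validates shapes and dtypes; this shows just how the primitive's predicate
// replaces an inline device check at the call site.
array rms_norm_sketch(
    const array& x,
    const array& w,
    float eps,
    StreamOrDevice s_ = {}) {
  auto s = to_stream(s_);

  // Reference implementation composed from existing ops; it doubles as the
  // fallback captured by the primitive.
  auto fallback = [eps, s](const std::vector<array>& inputs) {
    auto& xi = inputs[0];
    auto norm = rsqrt(
        add(mean(square(xi, s), -1, /* keepdims */ true, s),
            array(eps, xi.dtype()),
            s),
        s);
    return std::vector<array>{multiply(multiply(xi, norm, s), inputs[1], s)};
  };

  // The primitive, not the call site, decides whether the fast path applies.
  if (RMSNorm::use_fallback(s)) {
    return fallback({x, w})[0];
  }
  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<RMSNorm>(s, fallback, eps),
      {x, w});
}

} // namespace mlx::core::fast

With the check centralized in the primitive, eval_gpu can assume it only runs when the fast kernel actually applies.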

View File

@@ -10,6 +10,10 @@
namespace mlx::core::fast {
bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
}
}
bool LayerNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
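The hunks above only show the definitions; the matching declarations live in mlx/fast_primitives.h, which is not shown here. A sketch of their likely shape, with static inferred from the fact that the predicate is queried before the primitive object exists:

// Sketch of the declarations in mlx/fast_primitives.h (inferred, not copied
// from the diff): each fast primitive exposes its own fallback predicate.
class RMSNorm : public Custom {
 public:
  static bool use_fallback(Stream s);
  // ... constructor, eval_cpu/eval_gpu, etc. unchanged ...
};

class LayerNorm : public Custom {
 public:
  static bool use_fallback(Stream s);
  // ... unchanged members ...
};

RoPE gets the same one-argument form, while ScaledDotProductAttention's predicate also takes the input arrays and mask flags, as in the last hunk below.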

View File

@@ -7,6 +7,10 @@ namespace mlx::core::fast {
constexpr int n_per_thread = 4;
bool RoPE::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {

View File

@@ -4,10 +4,10 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
} // namespace
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full =
sdpa_full_supported_mask && sdpa_full_supported_head_dim;
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_head_dim;
return !(supports_sdpa_full || supports_sdpa_vector);
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
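The SDPA predicate is the interesting one: beyond the CPU-stream and gradient-tracing checks, it encodes which shapes the two Metal kernels accept (head dims 64/96/128/256 for the vector kernels with at most 8 query positions, head dims 64/80/128 for the full kernel). A hedged usage sketch probing the predicate directly; this is internal API, the shapes follow the (batch, heads, sequence, head_dim) layout implied by q.shape(2) and q.shape(-1) above, and a GPU stream is assumed to be available:

#include <cstdio>

#include "mlx/fast_primitives.h"
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto s = default_stream(Device::gpu);

  // Decode-style shapes: a single query position with head_dim 128 satisfies
  // the sdpa_vector conditions, so no fallback is needed.
  auto q = zeros({1, 32, 1, 128}, float16);
  auto k = zeros({1, 8, 4096, 128}, float16);
  auto v = zeros({1, 8, 4096, 128}, float16);
  bool fb_decode = fast::ScaledDotProductAttention::use_fallback(
      q, k, v, /* has_mask */ true, /* has_arr_mask */ false,
      /* do_causal */ true, s);
  std::printf("decode-style fallback: %d\n", fb_decode); // expected: 0

  // head_dim 96 with a long query sequence fails both paths: the full kernel
  // only handles 64/80/128 and the vector kernel needs <= 8 query positions.
  auto q2 = zeros({1, 16, 1024, 96}, float16);
  auto k2 = zeros({1, 16, 1024, 96}, float16);
  auto v2 = zeros({1, 16, 1024, 96}, float16);
  bool fb_long = fast::ScaledDotProductAttention::use_fallback(
      q2, k2, v2, true, false, true, s);
  std::printf("head_dim 96, long query fallback: %d\n", fb_long); // expected: 1
  return 0;
}

Either way the numerical result is the same; use_fallback only chooses whether the composed graph or a fused Metal kernel produces it.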