Fast primitives decide when to use the fallback (#2216)

commit 1b021f6984 (parent 95b7551d65)
Author: Cheng
Date:   2025-06-03 05:26:37 +09:00
Committed by: GitHub

7 changed files with 115 additions and 45 deletions

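The hunks below (from what appears to be mlx/fast.cpp) all make the same move: instead of the call site asking `s.device == Device::gpu`, each fast primitive now owns a static use_fallback predicate that decides whether the fused kernel runs or the composed-ops fallback does. A minimal sketch of the shape of such a predicate, assuming a Stream-based signature (the real declarations live in mlx/fast_primitives.h and may differ):

// Sketch only, not the repository's code: the simplest predicate just
// reproduces the old inline device check, but as a primitive-owned policy
// that every call site and backend can share.
#include "mlx/device.h"
#include "mlx/stream.h"

namespace mlx::core::fast {

struct RMSNormSketch { // hypothetical stand-in for the real RMSNorm primitive
  static bool use_fallback(Stream s) {
    // Fall back everywhere except the GPU, where the fused kernel exists.
    return !(s.device == Device::gpu);
  }
};

} // namespace mlx::core::fast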

@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
@@ -112,7 +111,8 @@ array rms_norm(
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),
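rms_norm, layer_norm, and rope all share the same caller-side shape; sketched generically below (FastPrim, out_shape, out_dtype, and inputs are placeholder names, not from the diff). The point of the design is that the composed-ops fallback closure is built unconditionally, and the predicate only picks which path produces the output array:

// Fragment illustrating the call-site pattern, placeholder names throughout.
auto stream = to_stream(s);
if (!FastPrim::use_fallback(stream)) {
  // Fused path: the primitive still carries `fallback`, so transforms the
  // kernel cannot handle can lower to the composed ops later.
  return array(
      out_shape,
      out_dtype,
      std::make_shared<FastPrim>(stream, fallback, /* op params */),
      inputs);
}
// Eager path: run the composed-ops implementation immediately.
return fallback(inputs)[0];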
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };
 
   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
-
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
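The block deleted above is exactly the policy that moves into the primitive. A sketch of how ScaledDotProductAttention::use_fallback could re-express it, with the signature read off the call site in the next hunk; the real body may differ, e.g. it presumably also returns true while detail::in_grad_tracing() is set, which is what lets the transforms_impl.h include at the top of this file go away:

// Sketch: the deleted support checks, re-homed as the primitive's predicate.
// Assumes the surrounding mlx::core::fast context (array, Stream, Device).
bool ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);

  const bool sdpa_vector_supported_head_dim =
      query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
       query_head_dim == 256);
  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_sdpa_full = sdpa_full_supported_mask &&
      sdpa_full_supported_head_dim && s.device == Device::gpu;
  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
      (query_sequence_length <= key_sequence_length) &&
      sdpa_vector_supported_head_dim && s.device == Device::gpu;

  // "Use the fallback" is the negation of "some fused kernel supports this".
  return !(supports_sdpa_full || supports_sdpa_vector);
}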
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
 
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
           stream, fallback, scale, do_causal),
       std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }
 
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
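A small ride-along in the last hunk: `inputs` is dead after this call, so std::move hands the std::vector<array> to the fallback instead of copying it; assuming the closure takes its inputs by value (as a stored std::function<std::vector<array>(std::vector<array>)> would), the move transfers the vector's buffer rather than retaining every element's shared storage. The idiom in isolation, with a hypothetical stand-in type:

#include <utility>
#include <vector>

struct array {}; // hypothetical stand-in for mlx::core::array

// Takes its argument by value, so callers can move into it.
std::vector<array> fallback(std::vector<array> inputs) {
  return inputs;
}

int main() {
  std::vector<array> inputs(3);
  // The vector's buffer is transferred, not copied, into the parameter.
  auto out = fallback(std::move(inputs));
  return out.size() == 3 ? 0 : 1;
}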