mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence length. Otherwise, utilizes aggregate MLX primitives for best latency.
This commit is contained in:
13
mlx/fast.cpp
13
mlx/fast.cpp
@@ -465,6 +465,7 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::optional<array>& mask,
|
||||
const std::optional<int>& memory_efficient_threshold,
|
||||
StreamOrDevice s) {
|
||||
for (const auto& tensor : {queries, keys, values}) {
|
||||
if (tensor.ndim() != 4) {
|
||||
@@ -535,6 +536,11 @@ array scaled_dot_product_attention(
|
||||
* * dtype is not fp32 or fp16
|
||||
*/
|
||||
|
||||
int threshold = 1e6;
|
||||
if (memory_efficient_threshold.has_value()) {
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
||||
bool needs_mask = mask.has_value();
|
||||
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
|
||||
const std::vector<array>& inputs) {
|
||||
@@ -581,9 +587,10 @@ array scaled_dot_product_attention(
|
||||
bool implementation_supports_use_case =
|
||||
supports_sdpa || supports_full_self_attention;
|
||||
|
||||
// disabling full self attention until perf is tuned;
|
||||
// likewise for sdpa
|
||||
implementation_supports_use_case &= false;
|
||||
// sdpa gpu shader is disabled except for memory efficient opt-in
|
||||
const int seq_for_threshold = queries.shape(2);
|
||||
bool use_memory_efficient_impl = seq_for_threshold >= threshold;
|
||||
implementation_supports_use_case &= use_memory_efficient_impl;
|
||||
|
||||
if (implementation_supports_use_case) {
|
||||
auto out_shape =
|
||||
|
@@ -37,6 +37,7 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::optional<array>& mask = std::nullopt,
|
||||
const std::optional<int>& memory_efficient_threshold = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
std::tuple<array, array, array> affine_quantize(
|
||||
|
Reference in New Issue
Block a user