mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
causal vector sdpa (#2018)
* causal vector sdpa * get rid of memory threshold
This commit is contained in:
16
mlx/fast.cpp
16
mlx/fast.cpp
@@ -568,7 +568,6 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
|
||||
const std::optional<int> memory_efficient_threshold,
|
||||
StreamOrDevice s) {
|
||||
for (const auto& tensor : {queries, keys, values}) {
|
||||
if (tensor.ndim() != 4) {
|
||||
@@ -654,13 +653,6 @@ array scaled_dot_product_attention(
|
||||
auto k = astype(keys, final_type, s);
|
||||
auto v = astype(values, final_type, s);
|
||||
|
||||
/* Generic implementation for use cases that Metal implementation does not
|
||||
* support. */
|
||||
int threshold = 32; // TODO: Fix after dev
|
||||
if (memory_efficient_threshold.has_value()) {
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
||||
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||
@@ -725,13 +717,13 @@ array scaled_dot_product_attention(
|
||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
||||
const bool sdpa_vector_supported_mask =
|
||||
!has_mask || has_bool_mask || do_causal;
|
||||
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||
(query_sequence_length <= key_sequence_length && do_causal);
|
||||
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
||||
stream.device == Device::gpu;
|
||||
const bool supports_sdpa_full = sdpa_full_supported_mask &&
|
||||
sdpa_full_supported_head_dim && stream.device == Device::gpu;
|
||||
|
||||
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
||||
(query_sequence_length <= key_sequence_length) &&
|
||||
|
||||
Reference in New Issue
Block a user