mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
fix mask in sdpa (#1980)
* fix mask in sdpa * fix attention mask * Re-enable routing for array mask --------- Co-authored-by: Jagrit Digani <digani@apple.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal(
|
||||
<< "_bq" << bq
|
||||
<< "_bk" << bk
|
||||
<< "_bd" << bd
|
||||
<< "_wm" << wm
|
||||
<< "_wm" << wm
|
||||
<< "_wn" << wn
|
||||
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
|
||||
|
||||
@@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal(
|
||||
|
||||
// clang-format off
|
||||
kname << "_align_Q_" << (align_Q ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_has_mask_" << (has_mask ? 't' : 'n')
|
||||
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
|
||||
|
||||
@@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal(
|
||||
|
||||
if (mask) {
|
||||
auto m = *mask;
|
||||
|
||||
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
|
||||
m.strides(0), m.strides(1), m.strides(2)}};
|
||||
|
||||
|
Reference in New Issue
Block a user