Fix causal mask sdpa vec (#2053)

* fix sdpa vector causal mask

* test
This commit is contained in:
Awni Hannun
2025-04-08 09:11:23 -07:00
committed by GitHub
parent 08a1bf3f10
commit 00794c42bc
2 changed files with 6 additions and 1 deletions

View File

@@ -172,6 +172,7 @@ void sdpa_vector(
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -268,6 +269,7 @@ void sdpa_vector_2pass(
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);