mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
@@ -172,6 +172,7 @@ void sdpa_vector(
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||
hash_name += do_causal ? "_c" : "_nc";
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -268,6 +269,7 @@ void sdpa_vector_2pass(
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||
hash_name += do_causal ? "_c" : "_nc";
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
Reference in New Issue
Block a user