Add new sdpa function overload (#2035)

* Add new sdpa function overload

* Address comments

* Remove std::varaint from cpp sdpa function
This commit is contained in:
Jagrit Digani
2025-04-03 11:58:28 -07:00
committed by GitHub
parent 8777fd104f
commit 3290bfa690
3 changed files with 71 additions and 17 deletions

View File

@@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
m.def(
"scaled_dot_product_attention",
&mx::fast::scaled_dot_product_attention,
[](const mx::array& queries,
const mx::array& keys,
const mx::array& values,
const float scale,
const std::variant<std::monostate, std::string, mx::array>& mask,
mx::StreamOrDevice s) {
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask =
has_mask && std::holds_alternative<std::string>(mask);
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
if (has_mask) {
if (has_str_mask) {
auto mask_str = std::get<std::string>(mask);
if (mask_str != "causal") {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] invalid mask option '"
<< mask_str << "'. Must be 'causal', or an array.";
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
}
} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
}
},
"q"_a,
"k"_a,
"v"_a,