mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add new sdpa function overload (#2035)
* Add new sdpa function overload * Address comments * Remove std::variant from cpp sdpa function
This commit is contained in:
@@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"scaled_dot_product_attention",
|
||||
&mx::fast::scaled_dot_product_attention,
|
||||
[](const mx::array& queries,
|
||||
const mx::array& keys,
|
||||
const mx::array& values,
|
||||
const float scale,
|
||||
const std::variant<std::monostate, std::string, mx::array>& mask,
|
||||
mx::StreamOrDevice s) {
|
||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||
bool has_str_mask =
|
||||
has_mask && std::holds_alternative<std::string>(mask);
|
||||
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
|
||||
|
||||
if (has_mask) {
|
||||
if (has_str_mask) {
|
||||
auto mask_str = std::get<std::string>(mask);
|
||||
if (mask_str != "causal") {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] invalid mask option '"
|
||||
<< mask_str << "'. Must be 'causal', or an array.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, mask_str, {}, s);
|
||||
} else {
|
||||
auto mask_arr = std::get<mx::array>(mask);
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, "", {mask_arr}, s);
|
||||
}
|
||||
|
||||
} else {
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, "", {}, s);
|
||||
}
|
||||
},
|
||||
"q"_a,
|
||||
"k"_a,
|
||||
"v"_a,
|
||||
|
Reference in New Issue
Block a user