Add new sdpa function overload (#2035)

* Add new sdpa function overload

* Address comments

* Remove std::variant from cpp sdpa function
This commit is contained in:
Jagrit Digani 2025-04-03 11:58:28 -07:00 committed by GitHub
parent 8777fd104f
commit 3290bfa690
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 17 deletions

View File

@ -567,8 +567,9 @@ array scaled_dot_product_attention(
const array& keys,
const array& values,
const float scale,
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
StreamOrDevice s) {
const std::string& mask_mode /* = "" */,
const std::vector<array>& mask_arrs /* = {} */,
StreamOrDevice s /* = {}*/) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
std::ostringstream msg;
@ -577,29 +578,49 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
}
// Check valid mask
if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
<< ". mask_mode must be 'causal', 'array' or ''.";
throw std::invalid_argument(msg.str());
}
bool do_causal = false;
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
bool has_mask = false;
bool has_arr_mask = false;
bool has_bool_mask = false;
if (has_str_mask) {
if (std::get<std::string>(mask) != "causal") {
if (mask_mode == "causal") {
has_mask = true;
do_causal = true;
if (!mask_arrs.empty()) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] invalid mask option '"
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
<< "'casusal'. No array masks supported.";
throw std::invalid_argument(msg.str());
} else {
do_causal = true;
}
}
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
if (mask_arrs.size() != 1) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
<< mask_arrs.size() << "arrays.";
throw std::invalid_argument(msg.str());
}
has_mask = true;
has_arr_mask = true;
has_bool_mask = mask_arrs[0].dtype() == bool_;
}
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] the mask with shape "
<< (std::get<array>(mask)).shape()
<< " expected to have at most rank 4";
<< mask_arrs[0].shape() << " expected to have at most rank 4.";
throw std::invalid_argument(msg.str());
}
@ -736,7 +757,7 @@ array scaled_dot_product_attention(
std::vector<array> inputs = {q, k, v};
if (has_arr_mask) {
// Check type
auto mask_arr = std::get<array>(mask);
auto mask_arr = mask_arrs[0];
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;

View File

@ -48,7 +48,8 @@ array scaled_dot_product_attention(
const array& keys,
const array& values,
const float scale,
const std::variant<std::monostate, std::string, array>& mask = {},
const std::string& mask_mode = "",
const std::vector<array>& mask_arrs = {},
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(

View File

@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
m.def(
"scaled_dot_product_attention",
&mx::fast::scaled_dot_product_attention,
// Binding adapter for `scaled_dot_product_attention`: accepts the mask as a
// single variant (nothing, a string, or an array) and forwards it to
// mx::fast::scaled_dot_product_attention as a (mask_mode, mask_arrs) pair.
//
//   - no mask     -> mask_mode "",       mask_arrs {}
//   - "causal"    -> mask_mode "causal", mask_arrs {}
//   - array mask  -> mask_mode "",       mask_arrs {mask}
//
// Throws std::invalid_argument when a string mask other than "causal" is
// supplied.
[](const mx::array& queries,
   const mx::array& keys,
   const mx::array& values,
   const float scale,
   const std::variant<std::monostate, std::string, mx::array>& mask,
   mx::StreamOrDevice s) {
  // String mask: "causal" is the only accepted value.
  if (const auto* mask_str = std::get_if<std::string>(&mask)) {
    if (*mask_str != "causal") {
      std::ostringstream msg;
      msg << "[scaled_dot_product_attention] invalid mask option '"
          << *mask_str << "'. Must be 'causal', or an array.";
      throw std::invalid_argument(msg.str());
    }
    return mx::fast::scaled_dot_product_attention(
        queries, keys, values, scale, *mask_str, {}, s);
  }

  // Array mask: passed along as the single entry of the mask array list.
  if (const auto* mask_arr = std::get_if<mx::array>(&mask)) {
    return mx::fast::scaled_dot_product_attention(
        queries, keys, values, scale, "", {*mask_arr}, s);
  }

  // std::monostate: no mask was provided.
  return mx::fast::scaled_dot_product_attention(
      queries, keys, values, scale, "", {}, s);
},
"q"_a,
"k"_a,
"v"_a,