Add new sdpa function overload (#2035)

* Add new sdpa function overload

* Address comments

* Remove std::varaint from cpp sdpa function
This commit is contained in:
Jagrit Digani 2025-04-03 11:58:28 -07:00 committed by GitHub
parent 8777fd104f
commit 3290bfa690
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 17 deletions

View File

@ -567,8 +567,9 @@ array scaled_dot_product_attention(
const array& keys, const array& keys,
const array& values, const array& values,
const float scale, const float scale,
const std::variant<std::monostate, std::string, array>& mask /* = {}*/, const std::string& mask_mode /* = "" */,
StreamOrDevice s) { const std::vector<array>& mask_arrs /* = {} */,
StreamOrDevice s /* = {}*/) {
for (const auto& tensor : {queries, keys, values}) { for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) { if (tensor.ndim() != 4) {
std::ostringstream msg; std::ostringstream msg;
@ -577,29 +578,49 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
// Check valid mask
if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
<< ". mask_mode must be 'causal', 'array' or ''.";
throw std::invalid_argument(msg.str());
}
bool do_causal = false; bool do_causal = false;
bool has_mask = !std::holds_alternative<std::monostate>(mask); bool has_mask = false;
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask); bool has_arr_mask = false;
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
bool has_bool_mask = false; bool has_bool_mask = false;
if (has_str_mask) { if (mask_mode == "causal") {
if (std::get<std::string>(mask) != "causal") { has_mask = true;
do_causal = true;
if (!mask_arrs.empty()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[scaled_dot_product_attention] invalid mask option '" msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array."; << "'casusal'. No array masks supported.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} else {
do_causal = true;
} }
} }
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) { if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
if (mask_arrs.size() != 1) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
<< mask_arrs.size() << "arrays.";
throw std::invalid_argument(msg.str());
}
has_mask = true;
has_arr_mask = true;
has_bool_mask = mask_arrs[0].dtype() == bool_;
}
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
std::ostringstream msg; std::ostringstream msg;
msg << "[scaled_dot_product_attention] the mask with shape " msg << "[scaled_dot_product_attention] the mask with shape "
<< (std::get<array>(mask)).shape() << mask_arrs[0].shape() << " expected to have at most rank 4.";
<< " expected to have at most rank 4";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -736,7 +757,7 @@ array scaled_dot_product_attention(
std::vector<array> inputs = {q, k, v}; std::vector<array> inputs = {q, k, v};
if (has_arr_mask) { if (has_arr_mask) {
// Check type // Check type
auto mask_arr = std::get<array>(mask); auto mask_arr = mask_arrs[0];
has_bool_mask = mask_arr.dtype() == bool_; has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) { if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg; std::ostringstream msg;

View File

@ -48,7 +48,8 @@ array scaled_dot_product_attention(
const array& keys, const array& keys,
const array& values, const array& values,
const float scale, const float scale,
const std::variant<std::monostate, std::string, array>& mask = {}, const std::string& mask_mode = "",
const std::vector<array>& mask_arrs = {},
StreamOrDevice s = {}); StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize( std::tuple<array, array, array> affine_quantize(

View File

@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
m.def( m.def(
"scaled_dot_product_attention", "scaled_dot_product_attention",
&mx::fast::scaled_dot_product_attention, [](const mx::array& queries,
const mx::array& keys,
const mx::array& values,
const float scale,
const std::variant<std::monostate, std::string, mx::array>& mask,
mx::StreamOrDevice s) {
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask =
has_mask && std::holds_alternative<std::string>(mask);
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
if (has_mask) {
if (has_str_mask) {
auto mask_str = std::get<std::string>(mask);
if (mask_str != "causal") {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] invalid mask option '"
<< mask_str << "'. Must be 'causal', or an array.";
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
}
} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
}
},
"q"_a, "q"_a,
"k"_a, "k"_a,
"v"_a, "v"_a,