mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Add new sdpa function overload (#2035)
* Add new sdpa function overload * Address comments * Remove std::variant from cpp sdpa function
This commit is contained in:
parent
8777fd104f
commit
3290bfa690
51
mlx/fast.cpp
51
mlx/fast.cpp
@ -567,8 +567,9 @@ array scaled_dot_product_attention(
|
||||
const array& keys,
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
|
||||
StreamOrDevice s) {
|
||||
const std::string& mask_mode /* = "" */,
|
||||
const std::vector<array>& mask_arrs /* = {} */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
for (const auto& tensor : {queries, keys, values}) {
|
||||
if (tensor.ndim() != 4) {
|
||||
std::ostringstream msg;
|
||||
@ -577,29 +578,49 @@ array scaled_dot_product_attention(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
// Check valid mask
|
||||
if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
|
||||
<< ". mask_mode must be 'causal', 'array' or ''.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
bool do_causal = false;
|
||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
|
||||
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
|
||||
bool has_mask = false;
|
||||
bool has_arr_mask = false;
|
||||
bool has_bool_mask = false;
|
||||
|
||||
if (has_str_mask) {
|
||||
if (std::get<std::string>(mask) != "causal") {
|
||||
if (mask_mode == "causal") {
|
||||
has_mask = true;
|
||||
do_causal = true;
|
||||
|
||||
if (!mask_arrs.empty()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] invalid mask option '"
|
||||
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
|
||||
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
||||
<< "'casusal'. No array masks supported.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
} else {
|
||||
do_causal = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
|
||||
if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
|
||||
if (mask_arrs.size() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
||||
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
|
||||
<< mask_arrs.size() << "arrays.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
has_mask = true;
|
||||
has_arr_mask = true;
|
||||
has_bool_mask = mask_arrs[0].dtype() == bool_;
|
||||
}
|
||||
|
||||
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] the mask with shape "
|
||||
<< (std::get<array>(mask)).shape()
|
||||
<< " expected to have at most rank 4";
|
||||
<< mask_arrs[0].shape() << " expected to have at most rank 4.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -736,7 +757,7 @@ array scaled_dot_product_attention(
|
||||
std::vector<array> inputs = {q, k, v};
|
||||
if (has_arr_mask) {
|
||||
// Check type
|
||||
auto mask_arr = std::get<array>(mask);
|
||||
auto mask_arr = mask_arrs[0];
|
||||
has_bool_mask = mask_arr.dtype() == bool_;
|
||||
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
||||
std::ostringstream msg;
|
||||
|
@ -48,7 +48,8 @@ array scaled_dot_product_attention(
|
||||
const array& keys,
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::variant<std::monostate, std::string, array>& mask = {},
|
||||
const std::string& mask_mode = "",
|
||||
const std::vector<array>& mask_arrs = {},
|
||||
StreamOrDevice s = {});
|
||||
|
||||
std::tuple<array, array, array> affine_quantize(
|
||||
|
@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"scaled_dot_product_attention",
|
||||
&mx::fast::scaled_dot_product_attention,
|
||||
[](const mx::array& queries,
|
||||
const mx::array& keys,
|
||||
const mx::array& values,
|
||||
const float scale,
|
||||
const std::variant<std::monostate, std::string, mx::array>& mask,
|
||||
mx::StreamOrDevice s) {
|
||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||
bool has_str_mask =
|
||||
has_mask && std::holds_alternative<std::string>(mask);
|
||||
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
|
||||
|
||||
if (has_mask) {
|
||||
if (has_str_mask) {
|
||||
auto mask_str = std::get<std::string>(mask);
|
||||
if (mask_str != "causal") {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] invalid mask option '"
|
||||
<< mask_str << "'. Must be 'causal', or an array.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, mask_str, {}, s);
|
||||
} else {
|
||||
auto mask_arr = std::get<mx::array>(mask);
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, "", {mask_arr}, s);
|
||||
}
|
||||
|
||||
} else {
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, "", {}, s);
|
||||
}
|
||||
},
|
||||
"q"_a,
|
||||
"k"_a,
|
||||
"v"_a,
|
||||
|
Loading…
Reference in New Issue
Block a user