mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-20 16:11:14 +08:00
Add new sdpa function overload (#2035)
* Add new sdpa function overload * Address comments * Remove std::varaint from cpp sdpa function
This commit is contained in:
parent
8777fd104f
commit
3290bfa690
51
mlx/fast.cpp
51
mlx/fast.cpp
@ -567,8 +567,9 @@ array scaled_dot_product_attention(
|
|||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const float scale,
|
||||||
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
|
const std::string& mask_mode /* = "" */,
|
||||||
StreamOrDevice s) {
|
const std::vector<array>& mask_arrs /* = {} */,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
for (const auto& tensor : {queries, keys, values}) {
|
for (const auto& tensor : {queries, keys, values}) {
|
||||||
if (tensor.ndim() != 4) {
|
if (tensor.ndim() != 4) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -577,29 +578,49 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Check valid mask
|
||||||
|
if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
|
||||||
|
<< ". mask_mode must be 'causal', 'array' or ''.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
bool do_causal = false;
|
bool do_causal = false;
|
||||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
bool has_mask = false;
|
||||||
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
|
bool has_arr_mask = false;
|
||||||
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
|
|
||||||
bool has_bool_mask = false;
|
bool has_bool_mask = false;
|
||||||
|
|
||||||
if (has_str_mask) {
|
if (mask_mode == "causal") {
|
||||||
if (std::get<std::string>(mask) != "causal") {
|
has_mask = true;
|
||||||
|
do_causal = true;
|
||||||
|
|
||||||
|
if (!mask_arrs.empty()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] invalid mask option '"
|
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
||||||
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
|
<< "'casusal'. No array masks supported.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
} else {
|
|
||||||
do_causal = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
|
if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
|
||||||
|
if (mask_arrs.size() != 1) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
||||||
|
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
|
||||||
|
<< mask_arrs.size() << "arrays.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
has_mask = true;
|
||||||
|
has_arr_mask = true;
|
||||||
|
has_bool_mask = mask_arrs[0].dtype() == bool_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] the mask with shape "
|
msg << "[scaled_dot_product_attention] the mask with shape "
|
||||||
<< (std::get<array>(mask)).shape()
|
<< mask_arrs[0].shape() << " expected to have at most rank 4.";
|
||||||
<< " expected to have at most rank 4";
|
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -736,7 +757,7 @@ array scaled_dot_product_attention(
|
|||||||
std::vector<array> inputs = {q, k, v};
|
std::vector<array> inputs = {q, k, v};
|
||||||
if (has_arr_mask) {
|
if (has_arr_mask) {
|
||||||
// Check type
|
// Check type
|
||||||
auto mask_arr = std::get<array>(mask);
|
auto mask_arr = mask_arrs[0];
|
||||||
has_bool_mask = mask_arr.dtype() == bool_;
|
has_bool_mask = mask_arr.dtype() == bool_;
|
||||||
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
|
@ -48,7 +48,8 @@ array scaled_dot_product_attention(
|
|||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const float scale,
|
||||||
const std::variant<std::monostate, std::string, array>& mask = {},
|
const std::string& mask_mode = "",
|
||||||
|
const std::vector<array>& mask_arrs = {},
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
std::tuple<array, array, array> affine_quantize(
|
std::tuple<array, array, array> affine_quantize(
|
||||||
|
@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"scaled_dot_product_attention",
|
"scaled_dot_product_attention",
|
||||||
&mx::fast::scaled_dot_product_attention,
|
[](const mx::array& queries,
|
||||||
|
const mx::array& keys,
|
||||||
|
const mx::array& values,
|
||||||
|
const float scale,
|
||||||
|
const std::variant<std::monostate, std::string, mx::array>& mask,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||||
|
bool has_str_mask =
|
||||||
|
has_mask && std::holds_alternative<std::string>(mask);
|
||||||
|
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
|
||||||
|
|
||||||
|
if (has_mask) {
|
||||||
|
if (has_str_mask) {
|
||||||
|
auto mask_str = std::get<std::string>(mask);
|
||||||
|
if (mask_str != "causal") {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] invalid mask option '"
|
||||||
|
<< mask_str << "'. Must be 'causal', or an array.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
return mx::fast::scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale, mask_str, {}, s);
|
||||||
|
} else {
|
||||||
|
auto mask_arr = std::get<mx::array>(mask);
|
||||||
|
return mx::fast::scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale, "", {mask_arr}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
return mx::fast::scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale, "", {}, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
"q"_a,
|
"q"_a,
|
||||||
"k"_a,
|
"k"_a,
|
||||||
"v"_a,
|
"v"_a,
|
||||||
|
Loading…
Reference in New Issue
Block a user