diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ae057edb7..994d50f10 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -567,8 +567,9 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::variant& mask /* = {}*/, - StreamOrDevice s) { + const std::string& mask_mode /* = "" */, + const std::vector& mask_arrs /* = {} */, + StreamOrDevice s /* = {}*/) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { std::ostringstream msg; @@ -577,29 +578,49 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } } + // Check valid mask + if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode + << ". mask_mode must be 'causal', 'array' or ''."; + throw std::invalid_argument(msg.str()); + } bool do_causal = false; - bool has_mask = !std::holds_alternative(mask); - bool has_str_mask = has_mask && std::holds_alternative(mask); - bool has_arr_mask = has_mask && std::holds_alternative(mask); + bool has_mask = false; + bool has_arr_mask = false; bool has_bool_mask = false; - if (has_str_mask) { - if (std::get(mask) != "causal") { + if (mask_mode == "causal") { + has_mask = true; + do_causal = true; + + if (!mask_arrs.empty()) { std::ostringstream msg; - msg << "[scaled_dot_product_attention] invalid mask option '" - << std::get(mask) << "'. Must be 'causal', or an array."; + msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode " + << "'casusal'. No array masks supported."; throw std::invalid_argument(msg.str()); - } else { - do_causal = true; } } - if (has_arr_mask && (std::get(mask)).ndim() > 4) { + if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) { + if (mask_arrs.size() != 1) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode " + << "'" << mask_mode << "'. Only 1 mask array is supported, got " + << mask_arrs.size() << "arrays."; + throw std::invalid_argument(msg.str()); + } + + has_mask = true; + has_arr_mask = true; + has_bool_mask = mask_arrs[0].dtype() == bool_; + } + + if (has_arr_mask && (mask_arrs[0]).ndim() > 4) { std::ostringstream msg; msg << "[scaled_dot_product_attention] the mask with shape " - << (std::get(mask)).shape() - << " expected to have at most rank 4"; + << mask_arrs[0].shape() << " expected to have at most rank 4."; throw std::invalid_argument(msg.str()); } @@ -736,7 +757,7 @@ array scaled_dot_product_attention( std::vector inputs = {q, k, v}; if (has_arr_mask) { // Check type - auto mask_arr = std::get(mask); + auto mask_arr = mask_arrs[0]; has_bool_mask = mask_arr.dtype() == bool_; if (promote_types(mask_arr.dtype(), final_type) != final_type) { std::ostringstream msg; diff --git a/mlx/fast.h b/mlx/fast.h index 462a5360d..7aebe3863 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -48,7 +48,8 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::variant& mask = {}, + const std::string& mask_mode = "", + const std::vector& mask_arrs = {}, StreamOrDevice s = {}); std::tuple affine_quantize( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 8f334a017..c94f99e1a 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) { m.def( "scaled_dot_product_attention", - &mx::fast::scaled_dot_product_attention, + [](const mx::array& queries, + const mx::array& keys, + const mx::array& values, + const float scale, + const std::variant& mask, + mx::StreamOrDevice s) { + bool has_mask = !std::holds_alternative(mask); + bool has_str_mask = + has_mask && std::holds_alternative(mask); + bool has_arr_mask = has_mask && std::holds_alternative(mask); + + if (has_mask) { + if (has_str_mask) { + auto mask_str = std::get(mask); + if (mask_str != "causal") { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] invalid mask option '" + << mask_str << "'. Must be 'causal', or an array."; + throw std::invalid_argument(msg.str()); + } + return mx::fast::scaled_dot_product_attention( + queries, keys, values, scale, mask_str, {}, s); + } else { + auto mask_arr = std::get(mask); + return mx::fast::scaled_dot_product_attention( + queries, keys, values, scale, "", {mask_arr}, s); + } + + } else { + return mx::fast::scaled_dot_product_attention( + queries, keys, values, scale, "", {}, s); + } + }, "q"_a, "k"_a, "v"_a,